Comments (8)

SkafteNicki commented on May 27, 2024

@michael080808 thanks for providing this overview, it really helps me.
So it seems to me that the best course of action is to take the implementation from https://github.com/dingkeyan93/IQA-optimization/blob/master/IQA_pytorch/CW_SSIM.py, make sure it works with newer versions of PyTorch, and then replace the backend to use https://github.com/LabForComputationalVision/plenoptic/blob/main/src/plenoptic/simulate/canonical_computations/steerable_pyramid_freq.py instead. I think that is definitely doable, but I have to look closer at the code and understand the metric a bit better.

from torchmetrics.

SkafteNicki commented on May 27, 2024

Hi @michael080808, thanks for opening this issue. We would be more than welcome to receive a pull request with this metric (either a partial or a full implementation), but I do not think anyone on the core team has the bandwidth or the experience to implement such a complex metric at the moment. If you can point me to a specific reference implementation, maybe I can take a stab at it.

michael080808 commented on May 27, 2024

> Hi @michael080808, thanks for opening this issue. We would be more than welcome to receive a pull request with this metric (either a partial or a full implementation), but I do not think anyone on the core team has the bandwidth or the experience to implement such a complex metric at the moment. If you can point me to a specific reference implementation, maybe I can take a stab at it.

There are some implementations as mentioned.

SkafteNicki commented on May 27, 2024

@michael080808 could you point me to which one of these has the most promising reference implementation? Which specific file is the metric implemented in?

michael080808 commented on May 27, 2024

> @michael080808 could you point me to which one of these has the most promising reference implementation? Which specific file is the metric implemented in?

The others are references for the DTCWT implementation. As I understand it, once the Steerable Pyramid (SP) or DTCWT has been implemented, CW-SSIM is easy to code, because CW-SSIM just applies the new SSIM definition to the transform output of SP or DTCWT. Implementing DTCWT without reference code may be very hard. The following two are specifically DTCWT implementations.
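To make the "new SSIM definition" concrete, here is a minimal sketch of the CW-SSIM index for one local window of complex subband coefficients, using the commonly cited form (2|Σ c_x c_y*| + K) / (Σ|c_x|² + Σ|c_y|² + K). The function name `cw_ssim_window` and the default `k` are illustrative assumptions, and plain Python complex numbers stand in for the output of a real SP or DTCWT transform.

```python
from typing import Sequence


def cw_ssim_window(cx: Sequence[complex], cy: Sequence[complex], k: float = 0.01) -> float:
    """CW-SSIM index of two equally sized windows of complex coefficients.

    Hypothetical helper for illustration; `k` is a small stabilizing constant.
    """
    if len(cx) != len(cy):
        raise ValueError("windows must have the same number of coefficients")
    # Numerator: magnitude of the summed conjugate products.
    cross = abs(sum(a * b.conjugate() for a, b in zip(cx, cy)))
    # Denominator: summed squared magnitudes of both windows.
    energy = sum(abs(a) ** 2 for a in cx) + sum(abs(b) ** 2 for b in cy)
    return (2 * cross + k) / (energy + k)


window = [1 + 2j, -0.5 + 1j, 3 - 1j]
# Rotating every coefficient by the same phase leaves the index at 1 (up to rounding).
shifted = [c * 1j for c in window]
print(cw_ssim_window(window, window))
print(cw_ssim_window(window, shifted))
```

The second call shows why the measure tolerates small translations: a consistent phase shift across the window changes the conjugate products only by a common phase factor, which the magnitude in the numerator discards.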

I hope that this information is helpful. If there are any other questions, I would be pleased to answer them. 🙂❤

michael080808 commented on May 27, 2024

I did some quick learning with

and tried a simple version of CW-SSIM. Here are the two parts of the code, running with PyTorch 2.2. I changed the coordinate convention so that the calculation behaves better when the input width or height is odd. I think complex convolution support deserves more attention: it is a very new feature, and CW-SSIM depends on it heavily. I hope this helps with understanding.
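The coordinate change for odd and even sizes can be illustrated without PyTorch. This sketch mirrors the arithmetic of `normalized_lin_spaces` from `pyramid.py` using plain Python lists; the function name `normalized_coords` is an illustrative assumption, and it assumes a length of at least 2.

```python
from typing import List


def normalized_coords(length: int) -> List[float]:
    """Map pixel centres 0..length-1 to a zero-centred range (assumes length >= 2)."""
    return [(i - length / 2 + 0.5) / (length // 2) for i in range(length)]


# Even length: the origin falls between two pixels, so every coordinate is offset by half a step.
print(normalized_coords(4))  # [-0.75, -0.25, 0.25, 0.75]
# Odd length: the centre pixel sits exactly on the origin and the ends reach -1 and 1.
print(normalized_coords(5))  # [-1.0, -0.5, 0.0, 0.5, 1.0]
```

In both cases the coordinates are symmetric about zero, so the radial and angular filter masks built from them stay centred on the shifted spectrum.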

# pyramid.py

"""
Put [0, Length - 1] into [-1, 1]
I prefer use pixel center as coordinate position
                  +-----+-----+-----+
                  |     |     |     |
                  |  A  |  B  |  C  |
                  |     |     |     |
+-----+-----+     +-----+-----+-----+
|     |     |     |     |     |     |
|  A  |  B  |     |  D  |  O  |  E  |
|     |     |     |     |     |     |
+-----O-----+     +-----+-----+-----+
|     |     |     |     |     |     |
|  C  |  D  |     |  F  |  G  |  H  |
|     |     |     |     |     |     |
+-----+-----+     +-----+-----+-----+
Here, O is the coordinate origin.
In even amount of pixels situation, A, B, C, D's coordinates are with half values.
In odd  amount of pixels situation, A, B, C, D, E, F, G, H's coordinates are without half values.
"""
import functools
import itertools
import math
import operator
from abc import ABCMeta, abstractmethod
from typing import List, Tuple, Union

import torch.fft
from torch import Tensor
from torch.types import Device


class SteerablePyramid:
    class _Filter(metaclass=ABCMeta):
        @staticmethod
        def bound_convert_2_tuple(boundary: Union[float, Tuple[float], Tuple[float, float]]) -> Tuple[float, float]:
            if isinstance(boundary, float):
                boundary = (boundary,)
            if isinstance(boundary, tuple) and len(boundary) == 1:
                boundary = (boundary[0], boundary[0])
            return boundary[0], boundary[1]

        @staticmethod
        def normalized_lin_spaces(length: int, device: Device = None) -> Tensor:
            number = torch.arange(length, device=device)
            return (number - length / 2 + 0.5) / (length // 2)

        @staticmethod
        def normalized_coordinate(shapes: Tuple[int, int], device: Device = None) -> Tuple[Tensor, ...]:
            coords = [SteerablePyramid._Filter.normalized_lin_spaces(length, device) for length in reversed(shapes)]
            return torch.meshgrid(*list(reversed(coords)), indexing='ij')

        @staticmethod
        def polars(shapes: Tuple[int, int], device: Device = None) -> Tuple[Tensor, Tensor]:
            x, y = SteerablePyramid._Filter.normalized_coordinate(shapes, device)
            return torch.sqrt(x ** 2 + y ** 2), torch.arctan2(y, x)

        @staticmethod
        def angles(shapes: Tuple[int, int], device: Device = None) -> Tensor:
            x, y = SteerablePyramid._Filter.normalized_coordinate(shapes, device)
            return torch.arctan2(y, x)

        @staticmethod
        def radius(shapes: Tuple[int, int], device: Device = None) -> Tensor:
            x, y = SteerablePyramid._Filter.normalized_coordinate(shapes, device)
            return torch.sqrt(x ** 2 + y ** 2)

        @staticmethod
        def bounds(shapes: Tuple[int, int], boundary: Union[float, Tuple[float], Tuple[float, float]], device: Device = None) -> Tensor:
            boundary = SteerablePyramid._Filter.bound_convert_2_tuple(boundary)
            return (boundary[0] * boundary[1]) / torch.sqrt((boundary[0] * torch.cos(SteerablePyramid._Filter.angles(shapes, device))) ** 2 + (boundary[1] * torch.sin(SteerablePyramid._Filter.angles(shapes, device))) ** 2)

        @staticmethod
        def high_band_pass_filter(shapes: Tuple[int, int], boundary: Union[float, Tuple[float], Tuple[float, float]] = 1.0, transition_width: float = 1.0, device: Device = None):
            diff = torch.log2(SteerablePyramid._Filter.radius(shapes, device)) - torch.log2(SteerablePyramid._Filter.bounds(shapes, boundary, device))
            return torch.abs(torch.cos((torch.clamp(diff, min=-transition_width, max=0) / transition_width) * (math.pi / 2)))

        @staticmethod
        def bass_band_pass_filter(shapes: Tuple[int, int], boundary: Union[float, Tuple[float], Tuple[float, float]] = 1.0, transition_width: float = 1.0, device: Device = None):
            high = SteerablePyramid._Filter.high_band_pass_filter(shapes, boundary=boundary, transition_width=transition_width, device=device)
            return torch.sqrt(1 - high ** 2)

        @abstractmethod
        def __init__(self):
            super().__init__()

        @abstractmethod
        def __call__(self, shapes: Tuple[int, int], device: Device = None):
            raise NotImplementedError

    class BassPassFilter(_Filter):
        def __init__(self, boundary: Union[float, Tuple[float], Tuple[float, float]] = 1.0, transition_width: float = 1.0):
            super().__init__()
            self.boundary = boundary
            self.transition_width = transition_width

        def __call__(self, shapes: Tuple[int, int], device: Device = None):
            return self.bass_band_pass_filter(shapes=shapes, boundary=self.boundary, transition_width=self.transition_width, device=device)

    class HighPassFilter(_Filter):
        def __init__(self, boundary: Union[float, Tuple[float], Tuple[float, float]] = 1.0, transition_width: float = 1.0):
            super().__init__()
            self.boundary = boundary
            self.transition_width = transition_width

        def __call__(self, shapes: Tuple[int, int], device: Device = None):
            return self.high_band_pass_filter(shapes=shapes, boundary=self.boundary, transition_width=self.transition_width, device=device)

    class BandPassFilter(_Filter):
        def __init__(self, boundary_high: Union[float, Tuple[float], Tuple[float, float]] = 1.0, boundary_bass: Union[float, Tuple[float], Tuple[float, float]] = 1.0, transition_width: float = 1.0):
            super().__init__()
            assert functools.reduce(operator.__and__, itertools.starmap(operator.ge, zip((boundary_high,) if isinstance(boundary_high, float) else boundary_high, (boundary_bass,) if isinstance(boundary_bass, float) else boundary_bass))), 'All elements from "boundary_high" must be greater than or equal to the corresponding elements in "boundary_bass".'
            self.boundary_bass, self.boundary_high, self.transition_width = boundary_bass, boundary_high, transition_width

        def __call__(self, shapes: Tuple[int, int], device: Device = None):
            return self.bass_band_pass_filter(shapes=shapes, boundary=self.boundary_high, transition_width=self.transition_width, device=device) * self.high_band_pass_filter(shapes=shapes, boundary=self.boundary_bass, transition_width=self.transition_width, device=device)

    class SteeringFilter(BandPassFilter):
        def __init__(self, boundary_bass: Union[float, Tuple[float], Tuple[float, float]] = 1.0, boundary_high: Union[float, Tuple[float], Tuple[float, float]] = 1.0, transition_width: float = 1.0, index: int = 0, orientations: int = 2, support_cplx: bool = False):
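            # The positional order is intentional: BandPassFilter expects (boundary_high, boundary_bass),
            # and the low-pass boundary passed here is the upper edge of the band.
            super().__init__(boundary_bass, boundary_high, transition_width)
            assert index < orientations, '"index" must be less than "orientations".'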
            self.index, self.orientations, self.support_cplx = index, orientations, support_cplx

        def __call__(self, shapes: Tuple[int, int], device: Device = None):
            return super().__call__(shapes, device) * self.orientation_filter(shapes, self.support_cplx, device)

        @property
        def constant(self):
            order = self.orientations - 1
            return math.pow(2, (2 * order)) * math.pow(math.factorial(order), 2) / (self.orientations * math.factorial(2 * order))

        def orientation_filter(self, shapes: Tuple[int, int], u4cplx: bool = False, device: Device = None):
            angles = torch.remainder(math.pi + self.angles(shapes, device) - math.pi * self.index / self.orientations, 2 * math.pi) - math.pi
            return (torch.abs(math.sqrt(self.constant) * torch.pow(torch.cos(angles), self.orientations - 1))) * (torch.lt(torch.abs(angles), math.pi / 2) if u4cplx else 1)

    @staticmethod
    def to_freq_domain(x: Tensor) -> Tensor:
        assert x.dim() >= 2, 'Not enough dimensions to run "to_freq_domain" procedure.'
        return torch.fft.fftshift(torch.fft.fft2(x, dim=[-2, -1]), dim=[-2, -1])

    @staticmethod
    def to_time_domain(x: Tensor) -> Tensor:
        assert x.dim() >= 2, 'Not enough dimensions to run "to_time_domain" procedure.'
        return torch.fft.ifft2(torch.fft.ifftshift(x, dim=[-2, -1]), dim=[-2, -1])

    @staticmethod
    def to_crop_region(entire: Tuple[int, int], region: Tuple[int, int]) -> Tuple[List[int], ...]:
        assert functools.reduce(operator.__and__, itertools.starmap(operator.ge, zip(entire, region))) and functools.reduce(operator.__and__, itertools.starmap(lambda x, y: (x - y) % 2 == 0, zip(entire, region))), 'All elements of "entire" must be greater than or equal to, and differ by an even amount from, the corresponding elements in "region".'
        return tuple([(shape - focal) // 2, focal, (shape - focal) // 2] for shape, focal in zip(entire, region))

    @staticmethod
    def to_crop_tensor(inputs: Tensor, region: Tuple[int, int]) -> Tensor:
        splits = SteerablePyramid.to_crop_region(entire=(inputs.shape[-2], inputs.shape[-1]), region=region)
        return torch.split(torch.split(inputs, splits[-1], dim=-1)[1], splits[-2], dim=-2)[1]

    @staticmethod
    def to_join_tensor(fronts: Tensor, backed: Tensor) -> Tensor:
        assert fronts.dim() == backed.dim() >= 2 and fronts.shape[-1] < backed.shape[-1] and fronts.shape[-2] < backed.shape[-2] and fronts.shape[:-2] == backed.shape[:-2], 'Unable to join two tensors into one due to the shape mismatch.'
        return torch.nn.functional.pad(fronts, [(backed.shape[-1] - fronts.shape[-1]) // 2] * 2 + [(backed.shape[-2] - fronts.shape[-2]) // 2] * 2, mode='constant', value=0) + backed * torch.nn.functional.pad(torch.zeros_like(fronts), [(backed.shape[-1] - fronts.shape[-1]) // 2] * 2 + [(backed.shape[-2] - fronts.shape[-2]) // 2] * 2, mode='constant', value=1)

    def __init__(self, group_levels: int = 6, orientations: int = 16, support_cplx: bool = True, transition_w: float = 1.0):
        super().__init__()
        self.group_levels = group_levels
        self.orientations = orientations
        self.support_cplx = support_cplx
        self.transition_w = transition_w

    def region_iteration(self, shapes: Tuple[int, int]):
        last = shapes
        yield last
        for i in range(self.group_levels):
            last = (last[0] - (last[0] // 4) * 2, last[1] - (last[1] // 4) * 2)
            yield last

    def factor_iteration(self, shapes: Tuple[int, int]):
        last = shapes
        yield tuple(itertools.starmap(operator.truediv, zip(last, shapes)))
        for _ in range(self.group_levels):
            last = (last[0] - (last[0] // 4) * 2, last[1] - (last[1] // 4) * 2)
            yield tuple(itertools.starmap(operator.truediv, zip(last, shapes)))

    def filter_iteration(self, shapes: Tuple[int, int], device: Device = None):
        # itertools.pairwise requires Python 3.10 or newer.
        iteration = zip(itertools.pairwise(self.factor_iteration(shapes)), self.region_iteration(shapes))
        for level, ((prev_f, curr_f), region) in enumerate(iteration):
            if level == 0:
                yield self.to_crop_tensor(self.HighPassFilter(boundary=prev_f, transition_width=self.transition_w)(shapes, device), region), f'H{level}'
                yield self.to_crop_tensor(self.BassPassFilter(boundary=prev_f, transition_width=self.transition_w)(shapes, device), region), f'L{level}'
            for orientation in range(self.orientations):
                yield self.to_crop_tensor(self.SteeringFilter(boundary_bass=prev_f, boundary_high=curr_f, transition_width=self.transition_w, index=orientation, orientations=self.orientations, support_cplx=self.support_cplx)(shapes, device), region), f'B{level + 1}o{orientation}'
            yield self.to_crop_tensor(self.BassPassFilter(boundary=curr_f, transition_width=self.transition_w)(shapes, device), region), f'L{level + 1}'

    def encode_iteration(self, tensor: Tensor):
        shapes = (tensor.shape[-2], tensor.shape[-1])
        target, window = self.to_freq_domain(tensor), None
        it_filter = self.filter_iteration(shapes, tensor.device)

        # L0 HighPass Output
        window = next(it_filter)
        time_domain = self.to_time_domain(target * window[0])
        yield time_domain if self.support_cplx else torch.real(time_domain), window[1]

        # L0 BassPass Remove
        window = next(it_filter)
        target = target * window[0]
        # yield time_domain if self.support_cplx else torch.real(time_domain), window[1] <- Removed due to definition.

        # Each Level Steering BandPass
        for level, (curr_r, next_r) in enumerate(itertools.pairwise(self.region_iteration(shapes))):
            for orientation in range(self.orientations):
                window = next(it_filter)
                time_domain = self.to_time_domain(target * window[0])
                yield time_domain if self.support_cplx else torch.real(time_domain), window[1]
            window = next(it_filter)
            target = self.to_crop_tensor(target * window[0], next_r)

        # Final BassPass
        time_domain = self.to_time_domain(target)
        yield time_domain if self.support_cplx else torch.real(time_domain), window[1]
# main.py

from typing import Tuple

import skimage
import torch
from skimage.data import astronaut
from torch import Tensor
from torch.nn import Module

from pyramid import SteerablePyramid


class CwSSIM(Module):
    def __init__(self, kernel: int = 7, k: float = 0, levels: int = 6, orientations: int = 16, transition_w: float = 1.0):
        super().__init__()
        self.k = k
        # Box filter used to sum coefficients over each local window.
        self.kernel = torch.ones([kernel] * 2)[None, None, ...]
        # Both inputs are decomposed with identically configured pyramids.
        self.result_pyramid = SteerablePyramid(group_levels=levels, orientations=orientations, transition_w=transition_w, support_cplx=True)
        self.ground_pyramid = SteerablePyramid(group_levels=levels, orientations=orientations, transition_w=transition_w, support_cplx=True)

    def multidim_conv2d(self, inputs: Tensor, *args, **kwargs) -> Tensor:
        if inputs.dim() <= 1:
            raise ValueError('One Dimensional Input is not supported.')
        channels = inputs.shape[-3] if inputs.dim() >= 3 else 1
        # torch.nn.functional.pad expects (left, right, top, bottom) for the last two dimensions.
        paddings = [self.kernel.size(dim) // 2 for dim in [-1, -2] for _ in range(2)]
        # "groups" is the fifth positional argument after the weight in conv2d; default to one group per channel.
        # Only set the keyword when it was not passed positionally, to avoid passing it twice.
        groups = args[4] if len(args) >= 5 else kwargs.setdefault('groups', channels)
        shapes = inputs.shape
        kernel = self.kernel.repeat(1, 1, 1, 1) if inputs.dim() == 2 else self.kernel.repeat(channels, channels // groups, 1, 1)

        if inputs.dim() >= 5:
            # Fold all leading dimensions into the batch dimension, convolve, then restore them.
            return torch.nn.functional.conv2d(torch.nn.functional.pad(inputs.flatten(0, -4), paddings, mode='reflect'), kernel.to(device=inputs.device, dtype=inputs.dtype), *args, **kwargs).unflatten(0, shapes[:-3])
        if 2 <= inputs.dim() <= 4:
            # Promote to 4 dimensions for conv2d, then drop the dimensions that were added.
            return torch.nn.functional.conv2d(torch.nn.functional.pad(inputs.repeat(1, 1, 1, 1), paddings, mode='reflect'), kernel.to(device=inputs.device, dtype=inputs.dtype), *args, **kwargs).squeeze(tuple(range(0, 4 - inputs.dim())))

    def statistics(self, result: Tensor, ground: Tensor, *args, **kwargs) -> Tuple[Tensor, ...]:
        # Windowed sum of conjugate products and windowed sum of squared magnitudes.
        conj_prods = self.multidim_conv2d(result * torch.conj(ground), *args, **kwargs)
        sum_mod_sq = self.multidim_conv2d(torch.abs(result) ** 2 + torch.abs(ground) ** 2, *args, **kwargs)
        return conj_prods, sum_mod_sq

    def forward(self, result: Tensor, ground: Tensor, *args, **kwargs) -> Tensor:
        assert result.shape == ground.shape
        # Use the instance pyramids so that the constructor arguments take effect.
        result_encode_iter = self.result_pyramid.encode_iteration(result)
        ground_encode_iter = self.ground_pyramid.encode_iteration(ground)

        # Keep the accumulator on the same device as the inputs.
        count, summarized = 0, torch.zeros(1, device=result.device)
        for (result_encode, _), (ground_encode, _) in zip(result_encode_iter, ground_encode_iter):
            conj_prods, sum_mod_sq = self.statistics(result_encode, ground_encode, *args, **kwargs)
            _ssim = (2 * torch.abs(conj_prods) + self.k) / (sum_mod_sq + self.k)
            count, summarized = count + 1, summarized + torch.mean(_ssim, dim=[-2, -1], keepdim=True)
        return summarized / count


ssim = CwSSIM()
if __name__ == '__main__':
    image = skimage.util.img_as_float32(astronaut())
    noise = skimage.util.random_noise(image, mode='speckle')
    print(ssim(torch.tensor(image).movedim(-1, 0), torch.tensor(noise).movedim(-1, 0)))

    image = skimage.util.img_as_float32(astronaut())
    noise = 0.8 * image
    print(ssim(torch.tensor(image).movedim(-1, 0), torch.tensor(noise).movedim(-1, 0)))

    image = skimage.util.img_as_float32(astronaut())
    noise = skimage.transform.rotate(image, 1)
    print(ssim(torch.tensor(image).movedim(-1, 0), torch.tensor(noise).movedim(-1, 0)))
# Result of CW-SSIM
tensor([[[0.8367]], [[0.8667]], [[0.8853]]], dtype=torch.float64)
tensor([[[0.9756]], [[0.9756]], [[0.9756]]])
tensor([[[0.8905]], [[0.8921]], [[0.8893]]])
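
The second result can be sanity-checked by hand. Assuming the pyramid is linear, scaling the image by a factor a scales every coefficient by a, so with k = 0 each local index reduces to 2a / (1 + a²) regardless of window or subband. The helper name below is hypothetical.

```python
def scaled_copy_index(a: float) -> float:
    """CW-SSIM index (k = 0) between coefficients c and a * c, for a real a > 0."""
    return 2 * a / (1 + a ** 2)


print(round(scaled_copy_index(0.8), 4))  # 0.9756
```

This agrees with the value printed for `noise = 0.8 * image` in all three channels.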

SkafteNicki commented on May 27, 2024

@michael080808 is what you posted a full implementation? It does not seem to rely on any third-party package?

michael080808 commented on May 27, 2024

> @michael080808 is what you posted a full implementation? It does not seem to rely on any third-party package?

It's a relatively complete implementation. I did not write the update and compute procedures for torchmetrics. I just rewrote the SteerablePyramid method to meet my requirements, so it does not rely on any other packages.
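
As a rough illustration of what the missing update and compute steps would do, the sketch below accumulates per-image scores across batches and averages them at the end. It deliberately does not use torchmetrics; the class `RunningMeanScore` and its methods are hypothetical stand-ins for the state handling that a real metric class would provide.

```python
from typing import Iterable


class RunningMeanScore:
    """Hypothetical accumulator that mimics an update/compute split."""

    def __init__(self) -> None:
        self.total = 0.0  # running sum of per-image scores
        self.count = 0    # number of scores seen so far

    def update(self, scores: Iterable[float]) -> None:
        """Fold one batch of per-image scores into the running state."""
        for score in scores:
            self.total += float(score)
            self.count += 1

    def compute(self) -> float:
        """Return the mean over everything seen since construction."""
        if self.count == 0:
            raise ValueError("compute() called before any update()")
        return self.total / self.count


metric = RunningMeanScore()
metric.update([0.5, 1.0])  # first batch
metric.update([0.75])      # second batch
print(metric.compute())    # 0.75
```

Keeping only a sum and a count means the state stays small no matter how many batches are seen, which is the usual reason for splitting a metric into these two steps.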
