Comments (8)
@michael080808 thanks for providing this overview, it really helps me.
So it seems to me that the best course of action is to take the implementation from https://github.com/dingkeyan93/IQA-optimization/blob/master/IQA_pytorch/CW_SSIM.py, make sure it works with newer versions of PyTorch, and then replace the backend with https://github.com/LabForComputationalVision/plenoptic/blob/main/src/plenoptic/simulate/canonical_computations/steerable_pyramid_freq.py. I think that is definitely doable, but I have to look more closely at the code and understand the metric a bit better.
from torchmetrics.
Hi @michael080808, thanks for opening this issue. We would be more than happy to receive a pull request with this metric (either a partial or a full implementation), but I do not think anyone on the core team has the bandwidth or the experience to implement such a complex metric at the moment. If you can point me to a specific reference implementation, maybe I can take a stab at it.
from torchmetrics.
Here are some existing implementations, as mentioned:
- https://github.com/jterrace/pyssim
- https://github.com/LeonArcher/py_cwt2d
- https://github.com/dingkeyan93/IQA-optimization
- https://github.com/fbcotter/pytorch_wavelets
- https://github.com/LabForComputationalVision/pyrtools
- https://github.com/LabForComputationalVision/plenoptic
from torchmetrics.
@michael080808 could you point me to which of these has the most promising reference implementation? Which specific file is the metric implemented in?
from torchmetrics.
- https://github.com/dingkeyan93/IQA-optimization/blob/master/IQA_pytorch/CW_SSIM.py
  This one provides a fully functional CW-SSIM with PyTorch 1.6 support. I just tested it on 2.2 and it does not work well.
  It uses part of https://github.com/LabForComputationalVision/pyrtools as its Steerable Pyramid backend.
  As https://github.com/LabForComputationalVision/pyrtools recommends, it is better to use the newer https://github.com/LabForComputationalVision/plenoptic as the Complex Steerable Pyramid backend.
  https://github.com/LabForComputationalVision/plenoptic/blob/main/src/plenoptic/simulate/canonical_computations/steerable_pyramid_freq.py provides a detailed Steerable Pyramid module. Here is an explanation of how the Steerable Pyramid works, with more detailed code: https://medium.com/@itberrios6/steerable-pyramids-6bfd4d23c10d
- https://github.com/jterrace/pyssim/blob/e4f864c4656d2b0041bd908faed8e32f72ace31e/ssim/ssimlib.py#L147
  also gives an implementation with scipy.signal.cwt, but I think it is not a proper implementation of CW-SSIM because of how scipy.signal.cwt is defined: according to https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.cwt.html it is a continuous wavelet transform, not a complex wavelet transform.
The others are references for DTCWT implementations. As I understand it, once a Steerable Pyramid or DTCWT is available, CW-SSIM is easy to code, because CW-SSIM just applies the new SSIM definition to the transform output of SP or DTCWT. Implementing DTCWT without reference code may be very hard, though. The following two are specific to DTCWT implementation:
- https://github.com/LeonArcher/py_cwt2d
- https://github.com/fbcotter/pytorch_wavelets/tree/master/pytorch_wavelets/dtcwt
I hope that this information is helpful. If there are any other questions, I would be pleased to answer them. 🙂❤
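To make the point about CW-SSIM being simple once the transform exists more concrete, here is a minimal sketch of the index for a single local window of one subband. It assumes the complex coefficients have already been produced by a Steerable Pyramid or DTCWT; the function name `cw_ssim_index` and the default value of the stabilizing constant `k` are illustrative choices, not taken from any of the linked repositories, and plain Python lists stand in for real coefficient tensors.

```python
import cmath
from typing import Sequence


def cw_ssim_index(cx: Sequence[complex], cy: Sequence[complex], k: float = 0.01) -> float:
    """CW-SSIM of two equally sized sets of complex coefficients (hypothetical helper).

    index = (2 * |sum(cx * conj(cy))| + k) / (sum(|cx|^2) + sum(|cy|^2) + k)
    """
    if len(cx) != len(cy):
        raise ValueError("coefficient sets must have the same length")
    cross = sum(a * b.conjugate() for a, b in zip(cx, cy))
    energy = sum(abs(a) ** 2 + abs(b) ** 2 for a, b in zip(cx, cy))
    return (2 * abs(cross) + k) / (energy + k)


coeffs = [1 + 2j, -0.5 + 0.3j, 0.7 - 1.1j, 2 + 0j]
rotated = [c * cmath.exp(1j * 0.4) for c in coeffs]  # same phase shift applied to every coefficient
unrelated = [0.3 - 1j, 1.5 + 0.2j, -2 + 0.1j, 0.1 + 0.9j]

print(cw_ssim_index(coeffs, coeffs))     # identical inputs score 1 (up to rounding)
print(cw_ssim_index(coeffs, rotated))    # a consistent phase shift leaves the score at 1 (up to rounding)
print(cw_ssim_index(coeffs, unrelated))  # unrelated coefficients score clearly lower
```

The insensitivity to a consistent phase shift is what makes the index tolerant of small translations and rotations, since those mostly rotate the phase of the complex coefficients.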
from torchmetrics.
I quickly learned the basics from
- https://medium.com/@itberrios6/steerable-pyramids-6bfd4d23c10d
- https://medium.com/@itberrios6/complex-steerable-pyramids-3cf7b99ff9fc
and tried a simple version of CW-SSIM. Here are the two parts of the code, running with PyTorch 2.2. I changed the coordinate handling so that the calculation behaves better when the input width or height is odd. I think complex convolution support deserves more attention: it is a very new feature and CW-SSIM depends heavily on it. I hope this helps with understanding.
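The pixel-center coordinate convention mentioned above can be illustrated with a few lines of plain Python. This is only a sketch: the helper names `pixel_center_coords` and `normalized_coords` are hypothetical, and the normalization by `length // 2` mirrors the expression used in the pyramid code posted below.

```python
from typing import List


def pixel_center_coords(length: int) -> List[float]:
    """Offsets of the pixel centers from the image center, in pixel units (hypothetical helper)."""
    return [i - length / 2 + 0.5 for i in range(length)]


def normalized_coords(length: int) -> List[float]:
    """The same offsets divided by length // 2."""
    return [c / (length // 2) for c in pixel_center_coords(length)]


print(pixel_center_coords(4))  # [-1.5, -0.5, 0.5, 1.5]       half-pixel offsets, no pixel at the origin
print(pixel_center_coords(5))  # [-2.0, -1.0, 0.0, 1.0, 2.0]  whole-pixel offsets, center pixel at the origin
print(normalized_coords(4))    # [-0.75, -0.25, 0.25, 0.75]
print(normalized_coords(5))    # [-1.0, -0.5, 0.0, 0.5, 1.0]
```

With an even length the origin falls on the corner shared by the central pixels; with an odd length it falls on the center pixel itself.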
# pyramid.py
"""
Map pixel indices [0, length - 1] into [-1, 1].
Pixel centers are used as coordinate positions.

                  +-----+-----+-----+
                  |     |     |     |
                  |  A  |  B  |  C  |
                  |     |     |     |
+-----+-----+     +-----+-----+-----+
|     |     |     |     |     |     |
|  A  |  B  |     |  D  |  O  |  E  |
|     |     |     |     |     |     |
+-----O-----+     +-----+-----+-----+
|     |     |     |     |     |     |
|  C  |  D  |     |  F  |  G  |  H  |
|     |     |     |     |     |     |
+-----+-----+     +-----+-----+-----+

Here, O is the coordinate origin.
With an even number of pixels, the centers of A, B, C, D lie at half-pixel offsets from O.
With an odd number of pixels, the centers of A, B, C, D, E, F, G, H lie at whole-pixel offsets from O.
"""
import functools
import itertools
import math
import operator
from abc import ABCMeta, abstractmethod
from typing import List, Tuple, Union

import torch.fft
from torch import Tensor
from torch.types import Device


class SteerablePyramid:
    class _Filter(metaclass=ABCMeta):
        """Base class for the frequency-domain masks; the static helpers build the coordinate grids."""

        @staticmethod
        def bound_convert_2_tuple(boundary: Union[float, Tuple[float], Tuple[float, float]]) -> Tuple[float, float]:
            if isinstance(boundary, float):
                boundary = (boundary,)
            if isinstance(boundary, tuple) and len(boundary) == 1:
                boundary = (boundary[0], boundary[0])
            return boundary[0], boundary[1]

        @staticmethod
        def normalized_lin_spaces(length: int, device: Device = None) -> Tensor:
            # Pixel-center offsets from the image center, divided by half the length.
            number = torch.arange(length, device=device)
            return (number - length / 2 + 0.5) / (length // 2)

        @staticmethod
        def normalized_coordinate(shapes: Tuple[int, int], device: Device = None) -> Tuple[Tensor, ...]:
            coords = [SteerablePyramid._Filter.normalized_lin_spaces(length, device) for length in reversed(shapes)]
            return torch.meshgrid(*list(reversed(coords)), indexing='ij')

        @staticmethod
        def polars(shapes: Tuple[int, int], device: Device = None) -> Tuple[Tensor, Tensor]:
            x, y = SteerablePyramid._Filter.normalized_coordinate(shapes, device)
            return torch.sqrt(x ** 2 + y ** 2), torch.arctan2(y, x)

        @staticmethod
        def angles(shapes: Tuple[int, int], device: Device = None) -> Tensor:
            x, y = SteerablePyramid._Filter.normalized_coordinate(shapes, device)
            return torch.arctan2(y, x)

        @staticmethod
        def radius(shapes: Tuple[int, int], device: Device = None) -> Tensor:
            x, y = SteerablePyramid._Filter.normalized_coordinate(shapes, device)
            return torch.sqrt(x ** 2 + y ** 2)

        @staticmethod
        def bounds(shapes: Tuple[int, int], boundary: Union[float, Tuple[float], Tuple[float, float]], device: Device = None) -> Tensor:
            # Radius of an ellipse with semi-axes "boundary" along each grid angle.
            boundary = SteerablePyramid._Filter.bound_convert_2_tuple(boundary)
            return (boundary[0] * boundary[1]) / torch.sqrt((boundary[0] * torch.cos(SteerablePyramid._Filter.angles(shapes, device))) ** 2 + (boundary[1] * torch.sin(SteerablePyramid._Filter.angles(shapes, device))) ** 2)

        @staticmethod
        def high_band_pass_filter(shapes: Tuple[int, int], boundary: Union[float, Tuple[float], Tuple[float, float]] = 1.0, transition_width: float = 1.0, device: Device = None):
            # Raised-cosine transition in log2(radius): 0 below the transition band, 1 at and above the boundary.
            diff = torch.log2(SteerablePyramid._Filter.radius(shapes, device)) - torch.log2(SteerablePyramid._Filter.bounds(shapes, boundary, device))
            return torch.abs(torch.cos((torch.clamp(diff, min=-transition_width, max=0) / transition_width) * (math.pi / 2)))

        @staticmethod
        def bass_band_pass_filter(shapes: Tuple[int, int], boundary: Union[float, Tuple[float], Tuple[float, float]] = 1.0, transition_width: float = 1.0, device: Device = None):
            # Power-complementary low-pass mask: low ** 2 + high ** 2 == 1.
            high = SteerablePyramid._Filter.high_band_pass_filter(shapes, boundary=boundary, transition_width=transition_width, device=device)
            return torch.sqrt(1 - high ** 2)

        @abstractmethod
        def __init__(self):
            super().__init__()

        @abstractmethod
        def __call__(self, shapes: Tuple[int, int], device: Device = None):
            raise NotImplementedError

    class BassPassFilter(_Filter):
        def __init__(self, boundary: Union[float, Tuple[float], Tuple[float, float]] = 1.0, transition_width: float = 1.0):
            super().__init__()
            self.boundary = boundary
            self.transition_width = transition_width

        def __call__(self, shapes: Tuple[int, int], device: Device = None):
            return self.bass_band_pass_filter(shapes=shapes, boundary=self.boundary, transition_width=self.transition_width, device=device)

    class HighPassFilter(_Filter):
        def __init__(self, boundary: Union[float, Tuple[float], Tuple[float, float]] = 1.0, transition_width: float = 1.0):
            super().__init__()
            self.boundary = boundary
            self.transition_width = transition_width

        def __call__(self, shapes: Tuple[int, int], device: Device = None):
            return self.high_band_pass_filter(shapes=shapes, boundary=self.boundary, transition_width=self.transition_width, device=device)

    class BandPassFilter(_Filter):
        def __init__(self, boundary_high: Union[float, Tuple[float], Tuple[float, float]] = 1.0, boundary_bass: Union[float, Tuple[float], Tuple[float, float]] = 1.0, transition_width: float = 1.0):
            super().__init__()
            assert functools.reduce(operator.__and__, itertools.starmap(operator.ge, zip((boundary_high,) if isinstance(boundary_high, float) else boundary_high, (boundary_bass,) if isinstance(boundary_bass, float) else boundary_bass))), 'All elements from "boundary_high" must be greater than or equal to the corresponding elements in "boundary_bass".'
            self.boundary_bass, self.boundary_high, self.transition_width = boundary_bass, boundary_high, transition_width

        def __call__(self, shapes: Tuple[int, int], device: Device = None):
            # Low-pass at the outer boundary times high-pass at the inner boundary.
            return self.bass_band_pass_filter(shapes=shapes, boundary=self.boundary_high, transition_width=self.transition_width, device=device) * self.high_band_pass_filter(shapes=shapes, boundary=self.boundary_bass, transition_width=self.transition_width, device=device)

    class SteeringFilter(BandPassFilter):
        def __init__(self, boundary_bass: Union[float, Tuple[float], Tuple[float, float]] = 1.0, boundary_high: Union[float, Tuple[float], Tuple[float, float]] = 1.0, transition_width: float = 1.0, index: int = 0, orientations: int = 2, support_cplx: bool = False):
            # Note: the two boundaries are passed positionally, so "boundary_bass" here becomes the outer boundary of BandPassFilter.
            super().__init__(boundary_bass, boundary_high, transition_width)
            assert index < orientations, '"index" must be less than "orientations".'
            self.index, self.orientations, self.support_cplx = index, orientations, support_cplx

        def __call__(self, shapes: Tuple[int, int], device: Device = None):
            return super().__call__(shapes, device) * self.orientation_filter(shapes, self.support_cplx, device)

        @property
        def constant(self):
            order = self.orientations - 1
            return math.pow(2, (2 * order)) * math.pow(math.factorial(order), 2) / (self.orientations * math.factorial(2 * order))

        def orientation_filter(self, shapes: Tuple[int, int], u4cplx: bool = False, device: Device = None):
            # Angle relative to this filter's orientation, wrapped into [-pi, pi).
            angles = torch.remainder(math.pi + self.angles(shapes, device) - math.pi * self.index / self.orientations, 2 * math.pi) - math.pi
            # For complex output only one half-plane of the spectrum is kept.
            return (torch.abs(math.sqrt(self.constant) * torch.pow(torch.cos(angles), self.orientations - 1))) * (torch.lt(torch.abs(angles), math.pi / 2) if u4cplx else 1)

    @staticmethod
    def to_freq_domain(x: Tensor) -> Tensor:
        assert x.dim() >= 2, 'Not enough dimensions to run "to_freq_domain" procedure.'
        return torch.fft.fftshift(torch.fft.fft2(x, dim=[-2, -1]), dim=[-2, -1])

    @staticmethod
    def to_time_domain(x: Tensor) -> Tensor:
        assert x.dim() >= 2, 'Not enough dimensions to run "to_time_domain" procedure.'
        return torch.fft.ifft2(torch.fft.ifftshift(x, dim=[-2, -1]), dim=[-2, -1])

    @staticmethod
    def to_crop_region(entire: Tuple[int, int], region: Tuple[int, int]) -> Tuple[List[int], ...]:
        assert functools.reduce(operator.__and__, itertools.starmap(operator.ge, zip(entire, region))) and functools.reduce(operator.__and__, itertools.starmap(lambda x, y: (x - y) % 2 == 0, zip(entire, region))), 'All elements of "entire" must be greater than or equal to the corresponding elements of "region" and differ from them by an even amount.'
        return tuple([(shape - focal) // 2, focal, (shape - focal) // 2] for shape, focal in zip(entire, region))

    @staticmethod
    def to_crop_tensor(inputs: Tensor, region: Tuple[int, int]) -> Tensor:
        # Keep the centered "region" of the last two dimensions.
        splits = SteerablePyramid.to_crop_region(entire=(inputs.shape[-2], inputs.shape[-1]), region=region)
        return torch.split(torch.split(inputs, splits[-1], dim=-1)[1], splits[-2], dim=-2)[1]

    @staticmethod
    def to_join_tensor(fronts: Tensor, backed: Tensor) -> Tensor:
        assert fronts.dim() == backed.dim() >= 2 and fronts.shape[-1] < backed.shape[-1] and fronts.shape[-2] < backed.shape[-2] and fronts.shape[:-2] == backed.shape[:-2], 'Unable to join two tensors into one due to the shape mismatch.'
        return torch.nn.functional.pad(fronts, [(backed.shape[-1] - fronts.shape[-1]) // 2] * 2 + [(backed.shape[-2] - fronts.shape[-2]) // 2] * 2, mode='constant', value=0) + backed * torch.nn.functional.pad(torch.zeros_like(fronts), [(backed.shape[-1] - fronts.shape[-1]) // 2] * 2 + [(backed.shape[-2] - fronts.shape[-2]) // 2] * 2, mode='constant', value=1)

    def __init__(self, group_levels: int = 6, orientations: int = 16, support_cplx: bool = True, transition_w: float = 1.0):
        super().__init__()
        self.group_levels = group_levels
        self.orientations = orientations
        self.support_cplx = support_cplx
        self.transition_w = transition_w

    def region_iteration(self, shapes: Tuple[int, int]):
        # Spectrum size kept at each level: roughly halved every step, always shrinking by an even amount.
        last = shapes
        yield last
        for i in range(self.group_levels):
            last = (last[0] - (last[0] // 4) * 2, last[1] - (last[1] // 4) * 2)
            yield last

    def factor_iteration(self, shapes: Tuple[int, int]):
        # The same sizes expressed as fractions of the original shape.
        last = shapes
        yield tuple(itertools.starmap(operator.truediv, zip(last, shapes)))
        for _ in range(self.group_levels):
            last = (last[0] - (last[0] // 4) * 2, last[1] - (last[1] // 4) * 2)
            yield tuple(itertools.starmap(operator.truediv, zip(last, shapes)))

    def filter_iteration(self, shapes: Tuple[int, int], device: Device = None):
        # itertools.pairwise requires Python 3.10 or newer.
        iteration = zip(itertools.pairwise(self.factor_iteration(shapes)), self.region_iteration(shapes))
        for level, ((prev_f, curr_f), region) in enumerate(iteration):
            if level == 0:
                yield self.to_crop_tensor(self.HighPassFilter(boundary=prev_f, transition_width=self.transition_w)(shapes, device), region), f'H{level}'
                yield self.to_crop_tensor(self.BassPassFilter(boundary=prev_f, transition_width=self.transition_w)(shapes, device), region), f'L{level}'
            for orientation in range(self.orientations):
                yield self.to_crop_tensor(self.SteeringFilter(boundary_bass=prev_f, boundary_high=curr_f, transition_width=self.transition_w, index=orientation, orientations=self.orientations, support_cplx=self.support_cplx)(shapes, device), region), f'B{level + 1}o{orientation}'
            yield self.to_crop_tensor(self.BassPassFilter(boundary=curr_f, transition_width=self.transition_w)(shapes, device), region), f'L{level + 1}'

    def encode_iteration(self, tensor: Tensor):
        shapes = (tensor.shape[-2], tensor.shape[-1])
        target, window = self.to_freq_domain(tensor), None
        it_filter = self.filter_iteration(shapes, tensor.device)
        # L0 HighPass Output
        window = next(it_filter)
        time_domain = self.to_time_domain(target * window[0])
        yield time_domain if self.support_cplx else torch.real(time_domain), window[1]
        # L0 BassPass Remove
        window = next(it_filter)
        target = target * window[0]
        # yield time_domain if self.support_cplx else torch.real(time_domain), window[1] <- Removed due to definition.
        # Each Level Steering BandPass
        for level, (curr_r, next_r) in enumerate(itertools.pairwise(self.region_iteration(shapes))):
            for orientation in range(self.orientations):
                window = next(it_filter)
                time_domain = self.to_time_domain(target * window[0])
                yield time_domain if self.support_cplx else torch.real(time_domain), window[1]
            # Low-pass the residual and crop the spectrum for the next level.
            window = next(it_filter)
            target = self.to_crop_tensor(target * window[0], next_r)
        # Final BassPass
        time_domain = self.to_time_domain(target)
        yield time_domain if self.support_cplx else torch.real(time_domain), window[1]
# main.py
from typing import Tuple

import skimage
import torch
from skimage.data import astronaut
from torch import Tensor
from torch.nn import Module

from pyramid import SteerablePyramid


class CwSSIM(Module):
    def __init__(self, kernel: int = 7, k: float = 0, levels: int = 6, orientations: int = 16, transition_w: float = 1.0):
        super().__init__()
        self.k = k
        # Unnormalized box window; with k = 0 its scale cancels in the ratio.
        self.kernel = torch.ones([kernel] * 2)[None, None, ...]
        self.result_pyramid = SteerablePyramid(group_levels=levels, orientations=orientations, transition_w=transition_w, support_cplx=True)
        self.ground_pyramid = SteerablePyramid(group_levels=levels, orientations=orientations, transition_w=transition_w, support_cplx=True)

    def multidim_conv2d(self, inputs: Tensor, *args, **kwargs) -> Tensor:
        if inputs.dim() <= 1:
            raise ValueError('One Dimensional Input is not supported.')
        channels = inputs.shape[-3] if inputs.dim() >= 3 else 1
        paddings = [self.kernel.size(dim) // 2 for _ in range(2) for dim in [-1, -2]]
        # Depthwise convolution by default: one window per channel.
        groups = args[4] if len(args) >= 5 else kwargs.get('groups', channels)
        kwargs['groups'] = groups
        shapes = inputs.shape
        kernel = self.kernel.repeat(1, 1, 1, 1) if inputs.dim() == 2 else self.kernel.repeat(channels, channels // groups, 1, 1)
        if inputs.dim() >= 5:
            return torch.nn.functional.conv2d(torch.nn.functional.pad(inputs.flatten(0, -4), paddings, mode='reflect'), kernel.to(device=inputs.device, dtype=inputs.dtype), *args, **kwargs).unflatten(0, shapes[:-3])
        if 2 <= inputs.dim() <= 4:
            return torch.nn.functional.conv2d(torch.nn.functional.pad(inputs.repeat(1, 1, 1, 1), paddings, mode='reflect'), kernel.to(device=inputs.device, dtype=inputs.dtype), *args, **kwargs).squeeze(tuple(range(0, 4 - inputs.dim())))

    def statistics(self, result: Tensor, ground: Tensor, *args, **kwargs) -> Tuple[Tensor, ...]:
        # Windowed sum of c_x * conj(c_y) and of |c_x|^2 + |c_y|^2.
        conj_prods = self.multidim_conv2d(result * torch.conj(ground), *args, **kwargs)
        sum_mod_sq = self.multidim_conv2d(torch.abs(result) ** 2 + torch.abs(ground) ** 2, *args, **kwargs)
        return conj_prods, sum_mod_sq

    def forward(self, result: Tensor, ground: Tensor, *args, **kwargs) -> Tensor:
        assert result.shape == ground.shape
        # Use the pyramids configured in __init__, so that "levels", "orientations" and "transition_w" take effect.
        result_encode_iter = self.result_pyramid.encode_iteration(result)
        ground_encode_iter = self.ground_pyramid.encode_iteration(ground)
        count, summarized = 0, torch.zeros(1, device=result.device)
        for (result_encode, _), (ground_encode, _) in zip(result_encode_iter, ground_encode_iter):
            conj_prods, sum_mod_sq = self.statistics(result_encode, ground_encode, *args, **kwargs)
            _ssim = (2 * torch.abs(conj_prods) + self.k) / (sum_mod_sq + self.k)
            # Average over spatial positions, then accumulate over subbands.
            count, summarized = count + 1, summarized + torch.mean(_ssim, dim=[-2, -1], keepdim=True)
        return summarized / count


ssim = CwSSIM()

if __name__ == '__main__':
    # Speckle noise
    image = skimage.util.img_as_float32(astronaut())
    noise = skimage.util.random_noise(image, mode='speckle')
    print(ssim(torch.tensor(image).movedim(-1, 0), torch.tensor(noise).movedim(-1, 0)))
    # Uniform intensity scaling
    image = skimage.util.img_as_float32(astronaut())
    noise = 0.8 * image
    print(ssim(torch.tensor(image).movedim(-1, 0), torch.tensor(noise).movedim(-1, 0)))
    # Rotation by 1 degree
    image = skimage.util.img_as_float32(astronaut())
    noise = skimage.transform.rotate(image, 1)
    print(ssim(torch.tensor(image).movedim(-1, 0), torch.tensor(noise).movedim(-1, 0)))
# Result of CW-SSIM
tensor([[[0.8367]], [[0.8667]], [[0.8853]]], dtype=torch.float64)
tensor([[[0.9756]], [[0.9756]], [[0.9756]]])
tensor([[[0.8905]], [[0.8921]], [[0.8893]]])
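As a sanity check on the second result: the pyramid is linear, so scaling the image by 0.8 scales every coefficient by 0.8, and with k = 0 the ratio computed in `forward` reduces to 2a / (1 + a²) in every window, independent of the image content. The short calculation below uses only that formula; the function name is a hypothetical label for it.

```python
def cw_ssim_of_scaled_copy(a: float) -> float:
    """CW-SSIM with k = 0 between coefficients c and a * c: 2|a| / (1 + a^2)."""
    return 2 * abs(a) / (1 + a * a)


print(round(cw_ssim_of_scaled_copy(0.8), 4))  # 0.9756
```

This matches the value printed for all three channels of the scaled image.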
from torchmetrics.
@michael080808 is what you posted a full implementation? It does not seem to rely on any third party package?
from torchmetrics.
It's a relatively complete implementation. I did not write the update and compute procedures that torchmetrics requires. I just rewrote the SteerablePyramid method to meet my requirements, so it does not rely on any other packages.
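For context, the missing update and compute step is mostly bookkeeping. The sketch below is a plain-Python stand-in for that accumulation pattern, not an actual torchmetrics `Metric` subclass: the class name `RunningCwSsim` and its attributes are hypothetical, and in torchmetrics the two accumulators would instead be registered as metric states (via `add_state`) so that they can be reduced across processes.

```python
from typing import Iterable


class RunningCwSsim:
    """Hypothetical accumulator mirroring the update()/compute() split of a metric."""

    def __init__(self) -> None:
        self.score_sum = 0.0  # sum of per-image CW-SSIM scores seen so far
        self.num_images = 0   # number of images seen so far

    def update(self, batch_scores: Iterable[float]) -> None:
        """Fold the per-image scores of one batch into the running totals."""
        for score in batch_scores:
            self.score_sum += float(score)
            self.num_images += 1

    def compute(self) -> float:
        """Return the mean score over everything passed to update()."""
        if self.num_images == 0:
            raise RuntimeError("compute() called before any update()")
        return self.score_sum / self.num_images


metric = RunningCwSsim()
metric.update([0.84, 0.87])  # first batch of per-image scores
metric.update([0.98])        # second batch
print(metric.compute())      # mean over all three images
```

Keeping only a running sum and a count means the per-image scores never have to be stored, which is the usual reason for splitting a metric into these two steps.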
from torchmetrics.