Adding documentation for new metrics (#368)
* docs: add missing entries

* docs: fix acronym

* docs: fix readme

* docs: formatting

* Removed old documentation deployment

* docs: move to NR

* docs: adding missing examples

* docs: fix docstring for CLIPIQA

* fix: tests for clipiqa
denproc authored Jul 4, 2023
1 parent c26ce7c commit 81ac1e0
Showing 8 changed files with 77 additions and 58 deletions.
19 changes: 0 additions & 19 deletions .github/workflows/cd-documentation.yml

This file was deleted.

9 changes: 5 additions & 4 deletions README.rst
@@ -175,13 +175,13 @@ VIFp 2004 `Visual Information Fidelity <https://ieeexplore.ieee.org/d
FSIM 2011 `Feature Similarity Index Measure <https://ieeexplore.ieee.org/document/5705575>`_
SR-SIM 2012 `Spectral Residual Based Similarity <https://sse.tongji.edu.cn/linzhang/ICIP12/ICIP-SR-SIM.pdf>`_
GMSD 2013 `Gradient Magnitude Similarity Deviation <https://arxiv.org/abs/1308.3052>`_
MS-GMSD 2017 `Multi-Scale Gradient Magnitude Similarity Deviation <https://ieeexplore.ieee.org/document/7952357>`_
VSI 2014 `Visual Saliency-induced Index <https://ieeexplore.ieee.org/document/6873260>`_
DSS 2015 `DCT Subband Similarity Index <https://ieeexplore.ieee.org/document/7351172>`_
\- 2016 `Content Score <https://arxiv.org/abs/1508.06576>`_
\- 2016 `Style Score <https://arxiv.org/abs/1508.06576>`_
HaarPSI 2016 `Haar Perceptual Similarity Index <https://arxiv.org/abs/1607.06140>`_
MDSI 2016 `Mean Deviation Similarity Index <https://arxiv.org/abs/1608.07433>`_
MS-GMSD 2017 `Multi-Scale Gradient Magnitude Similarity Deviation <https://ieeexplore.ieee.org/document/7952357>`_
LPIPS 2018 `Learned Perceptual Image Patch Similarity <https://arxiv.org/abs/1801.03924>`_
PieAPP 2018 `Perceptual Image-Error Assessment through Pairwise Preference <https://arxiv.org/abs/1806.02067>`_
DISTS 2020 `Deep Image Structure and Texture Similarity <https://arxiv.org/abs/2004.07728>`_
@@ -195,6 +195,7 @@ Acronym Year Metric
=========== ====== ==========
TV 1937 `Total Variation <https://en.wikipedia.org/wiki/Total_variation>`_
BRISQUE 2012 `Blind/Referenceless Image Spatial Quality Evaluator <https://ieeexplore.ieee.org/document/6272356>`_
CLIP-IQA 2022 `CLIP-IQA <https://arxiv.org/pdf/2207.12396.pdf>`_
=========== ====== ==========

Distribution-Based (DB)
@@ -277,11 +278,11 @@ GS 0.37 / - 0.37 / - 0.02 / -
No-Reference (NR) Datasets
^^^^^^^^^^^^^^^^^^^^^^^^^^
=========== =========================== ===========================
\ KonIQ10k LIVE-itW
\ KonIQ10k LIVE-itW
----------- --------------------------- ---------------------------
Source PIQ / Reference PIQ / Reference
Source PIQ / Reference PIQ / Reference
=========== =========================== ===========================
BRISQUE 0.22 / - 0.31 / -
BRISQUE 0.22 / - 0.31 / -
CLIP-IQA 0.68 / 0.68 `CLIP-IQA off`_ 0.64 / 0.64 `CLIP-IQA off`_
=========== =========================== ===========================
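
The numbers in this benchmark table are correlations between metric scores and human (MOS) ratings on each dataset. A minimal sketch of how a single cell, such as the CLIP-IQA entry, could be reproduced, assuming per-image scores and MOS labels are already available, that CLIPIQA returns one score per image, and that SRCC is the correlation used (dataset loading is omitted):

```python
import torch
import piq
from scipy import stats

# Stand-ins for dataset images in [0, 1] and their mean opinion scores (MOS).
images = torch.rand(8, 3, 224, 224)
mos = torch.rand(8)

# Per-image no-reference CLIP-IQA scores (assumed shape: one score per image).
scores = piq.CLIPIQA(data_range=1.)(images).squeeze()

# Spearman rank correlation between metric scores and human ratings.
srcc, _ = stats.spearmanr(scores.cpu().numpy(), mos.numpy())
print(f"SRCC: {srcc:0.2f}")
```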

8 changes: 8 additions & 0 deletions docs/source/functions.rst
@@ -27,6 +27,10 @@ Feature Similarity Index Measure (FSIM)
'''''''''''''''''''''''''''''''''''''''
.. autofunction:: piq.fsim

Spectral Residual based Similarity (SR-SIM)
'''''''''''''''''''''''''''''''''''''''''''
.. autofunction:: piq.srsim

Gradient Magnitude Similarity Deviation (GMSD)
''''''''''''''''''''''''''''''''''''''''''''''
.. autofunction:: piq.gmsd
@@ -39,6 +43,10 @@ Visual Saliency-induced Index (VSI)
'''''''''''''''''''''''''''''''''''
.. autofunction:: piq.vsi

DCT Subband Similarity (DSS)
''''''''''''''''''''''''''''
.. autofunction:: piq.dss

Haar Perceptual Similarity Index (HaarPSI)
''''''''''''''''''''''''''''''''''''''''''
.. autofunction:: piq.haarpsi
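
The two functions newly listed in `functions.rst` follow the same functional interface as the other full-reference metrics in piq. A short usage sketch, assuming the usual `(x, y, data_range=...)` signature with the default mean reduction:

```python
import torch
import piq

x = torch.rand(4, 3, 256, 256)  # distorted images in [0, 1]
y = torch.rand(4, 3, 256, 256)  # reference images in [0, 1]

# Spectral Residual based Similarity (SR-SIM), newly added to the functional docs.
srsim_index: torch.Tensor = piq.srsim(x, y, data_range=1.)

# DCT Subband Similarity (DSS), also newly documented.
dss_index: torch.Tensor = piq.dss(x, y, data_range=1.)

print(f"SR-SIM: {srsim_index.item():0.4f}, DSS: {dss_index.item():0.4f}")
```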
20 changes: 18 additions & 2 deletions docs/source/modules.rst
@@ -3,8 +3,8 @@ Class Interface

Full Reference Metrics
----------------------
Structural Similarity
'''''''''''''''''''''
Structural Similarity (SSIM)
''''''''''''''''''''''''''''
.. autoclass:: piq.SSIMLoss
:members:

@@ -28,6 +28,11 @@ Feature Similarity Index Measure (FSIM)
.. autoclass:: piq.FSIMLoss
:members:

Spectral Residual based Similarity Measure (SR-SIM)
'''''''''''''''''''''''''''''''''''''''''''''''''''
.. autoclass:: piq.SRSIMLoss
:members:

Gradient Magnitude Similarity Deviation (GMSD)
''''''''''''''''''''''''''''''''''''''''''''''
.. autoclass:: piq.GMSDLoss
@@ -43,6 +48,11 @@ Visual Saliency-induced Index (VSI)
.. autoclass:: piq.VSILoss
:members:

DCT Subband Similarity Index (DSS)
'''''''''''''''''''''''''''''''''''
.. autoclass:: piq.DSSLoss
:members:

Haar Perceptual Similarity Index (HaarPSI)
''''''''''''''''''''''''''''''''''''''''''
.. autoclass:: piq.HaarPSILoss
@@ -91,6 +101,12 @@ Blind/Referenceless Image Spatial Quality Evaluator (BRISQUE)
.. autoclass:: piq.BRISQUELoss
:members:

CLIP-IQA
''''''''
.. autoclass:: piq.CLIPIQA
:members:


Feature Metrics
---------------
Inception Score (IS)
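
Likewise, a hedged sketch of the class-interface entries added above (SRSIMLoss, DSSLoss, CLIPIQA); constructor arguments other than `data_range` are assumed to mirror the existing loss modules:

```python
import torch
import piq

x = torch.rand(4, 3, 256, 256)  # e.g. distorted or generated images in [0, 1]
y = torch.rand(4, 3, 256, 256)  # reference images

# Full-reference losses newly added to the class interface.
srsim_loss: torch.Tensor = piq.SRSIMLoss(data_range=1.)(x, y)
dss_loss: torch.Tensor = piq.DSSLoss(data_range=1.)(x, y)

# No-reference CLIP-IQA module; note it takes only the image batch.
clip_iqa_score: torch.Tensor = piq.CLIPIQA(data_range=1.)(x)

print(srsim_loss.item(), dss_loss.item(), clip_iqa_score.mean().item())
```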
5 changes: 5 additions & 0 deletions examples/feature_metrics.py
@@ -39,6 +39,11 @@ def main():
msid: torch.Tensor = piq.MSID()(x_features, y_features)
print(f"MSID: {msid:0.4f}")

# Use PR class to compute Improved Precision and Recall score from image features,
# pre-extracted from some feature extractor network:
pr: tuple = piq.PR()(x_features, y_features)
print(f"Improved Precision and Recall: {pr[0]:0.4f} {pr[1]:0.4f}")


if __name__ == '__main__':
main()
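
The `x_features` / `y_features` tensors consumed by `piq.PR()` above are extracted earlier in the example script, outside this hunk. A hypothetical stand-in that shows the expected shape of such feature batches, using a torchvision backbone rather than whatever extractor the example actually uses:

```python
import torch
import torchvision.models as models
import piq

# Hypothetical feature extractor: ResNet-18 without its classification head,
# so each image maps to a 512-dimensional, globally pooled feature vector.
backbone = models.resnet18(weights=None).eval()
extractor = torch.nn.Sequential(*list(backbone.children())[:-1])

with torch.no_grad():
    x_features = extractor(torch.rand(16, 3, 224, 224)).flatten(1)  # (16, 512)
    y_features = extractor(torch.rand(16, 3, 224, 224)).flatten(1)  # (16, 512)

precision, recall = piq.PR()(x_features, y_features)
print(f"Improved Precision and Recall: {precision:0.4f} {recall:0.4f}")
```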
10 changes: 10 additions & 0 deletions examples/image_metrics.py
@@ -22,6 +22,10 @@ def main():
brisque_loss: torch.Tensor = piq.BRISQUELoss(data_range=1., reduction='none')(x)
print(f"BRISQUE index: {brisque_index.item():0.4f}, loss: {brisque_loss.item():0.4f}")

# To compute CLIP-IQA score as a measure, use PyTorch module from the library
clip_iqa_index: torch.Tensor = piq.CLIPIQA(data_range=1.).to(x.device)(x)
print(f"CLIP-IQA: {clip_iqa_index.item():0.4f}")

# To compute Content score as a loss function, use corresponding PyTorch module
# By default VGG16 model is used, but any feature extractor model is supported.
# Don't forget to adjust layers names accordingly. Features from different layers can be weighted differently.
@@ -63,6 +67,12 @@ def main():
haarpsi_loss: torch.Tensor = piq.HaarPSILoss(data_range=1., reduction='none')(x, y)
print(f"HaarPSI index: {haarpsi_index.item():0.4f}, loss: {haarpsi_loss.item():0.4f}")

# To compute IW-SSIM index as a measure, use lower case function from the library:
iw_ssim_index: torch.Tensor = piq.information_weighted_ssim(x, y, data_range=1.)
# In order to use IW-SSIM as a loss function, use corresponding PyTorch module:
iw_ssim_loss = piq.InformationWeightedSSIMLoss(data_range=1., reduction='none').to(x.device)(x, y)
print(f"IW-SSIM index: {iw_ssim_index.item():0.4f}, loss: {iw_ssim_loss.item():0.4f}")

# To compute LPIPS as a loss function, use corresponding PyTorch module
lpips_loss: torch.Tensor = piq.LPIPS(reduction='none')(x, y)
print(f"LPIPS: {lpips_loss.item():0.4f}")
32 changes: 18 additions & 14 deletions piq/clip_iqa.py
@@ -26,7 +26,8 @@

class CLIPIQA(_Loss):
r"""Creates a criterion that measures image quality based on a general notion of text-to-image similarity
learned by the CLIP[1] model during its large-scale pre-training on a large dataset with paired texts and images.
learned by the CLIP model (Radford et al., 2021) during its large-scale pre-training on a large dataset
with paired texts and images.
The method is based on the idea that two antonyms ("Good photo" and "Bad photo") can be used as anchors in the
text embedding space representing good and bad images in terms of their image quality.
@@ -35,13 +36,14 @@ class CLIPIQA(_Loss):
1. Compute the image embedding of the image of interest using the pre-trained CLIP model;
2. Compute the text embeddings of the selected anchor antonyms;
3. Compute the angle (cosine similarity) between the image embedding (1) and both text embeddings (2);
4. Compute the Softmax of cosine similarities (3) -> CLIP-IQA[2] score.
4. Compute the Softmax of cosine similarities (3) -> CLIP-IQA score (Wang et al., 2022).
This method is proposed to eliminate the linguistic ambiguity of the naive approach
(using a single prompt, e.g., "Good photo").
This method has an extension called CLIP-IQA+[2] proposed in the same research paper.
It uses the same approach but also fine-tunes the CLIP weights using the CoOp[3] fine-tuning algorithm.
This method has an extension called CLIP-IQA+ proposed in the same research paper.
It uses the same approach but also fine-tunes the CLIP weights using the CoOp
fine-tuning algorithm (Zhou et al., 2022).
Note:
The initial computation of the metric is performed in `float32` and other dtypes (i.e. `float16`, `float64`)
@@ -50,7 +52,7 @@ class CLIPIQA(_Loss):
Warning:
In order to avoid implicit dtype conversion and normalization of input tensors, they are copied.
Note that it may consume extra memory, which might be noticible on large batch sizes.
Note that it may consume extra memory, which might be noticeable on large batch sizes.
Args:
data_range: Maximum value range of images (usually 1.0 or 255).
@@ -62,11 +64,13 @@ class CLIPIQA(_Loss):
>>> score = clipiqa(x)
References:
[1] Radford, Alec, et al. "Learning transferable visual models from natural language supervision."
Radford, Alec, et al. "Learning transferable visual models from natural language supervision."
International conference on machine learning. PMLR, 2021.
[2] Wang, Jianyi, Kelvin CK Chan, and Chen Change Loy. "Exploring CLIP for Assessing the Look
Wang, Jianyi, Kelvin CK Chan, and Chen Change Loy. "Exploring CLIP for Assessing the Look
and Feel of Images." arXiv preprint arXiv:2207.12396 (2022).
[3] Zhou, Kaiyang, et al. "Learning to prompt for vision-language models." International
Zhou, Kaiyang, et al. "Learning to prompt for vision-language models." International
Journal of Computer Vision 130.9 (2022): 2337-2348.
"""
def __init__(self, data_range: Union[float, int] = 1.) -> None:
@@ -75,7 +79,7 @@ def __init__(self, data_range: Union[float, int] = 1.) -> None:
self.feature_extractor = clip.load().eval()
for param in self.feature_extractor.parameters():
param.requires_grad = False

# Pre-computed tokens for prompt pairs: "Good photo.", "Bad photo.".
tokens = download_tensor(TOKENS_URL, os.path.expanduser("~/.cache/clip"))

@@ -96,16 +100,16 @@ def forward(self, x_input: torch.Tensor) -> torch.Tensor:
r"""Computation of CLIP-IQA metric for a given image :math:`x`.
Args:
x: An input tensor. Shape :math:`(N, C, H, W)` or :math:`(C, H, W)`.
x: An input tensor. Shape :math:`(N, C, H, W)`.
The metric is designed in such a way that it expects:
- 3D or 4D PyTorch tensors;
- These tensors are have any ranges of values between 0 and 255;
- These tensros have channels first format.
- A 4D PyTorch tensor;
- The tensor might have flexible data ranges depending on `data_range` value;
- The tensor must have channels first format.
Returns:
The value of CLIP-IQA score in [0, 1] range.
"""
_validate_input([x_input], dim_range=(3, 4), data_range=(0., 255.), check_for_channels_first=True)
_validate_input([x_input], dim_range=(4, 4), data_range=(0., 255.), check_for_channels_first=True)

x = x_input.clone()
x = x.float() / self.data_range
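
The docstring changes above describe the whole metric as a small amount of tensor arithmetic: cosine similarity between one image embedding and two prompt embeddings, followed by a softmax. A self-contained sketch of just that step, using random stand-in embeddings instead of real CLIP features (the library implementation may additionally apply CLIP's learned logit scaling before the softmax):

```python
import torch
import torch.nn.functional as F

# Stand-ins for the embeddings described in the docstring.
image_embedding = torch.randn(1, 512)      # embedding of the image of interest
anchor_embeddings = torch.randn(2, 512)    # "Good photo." / "Bad photo." prompts

# Step 3: cosine similarity between the image and each anchor prompt.
cos_sim = F.cosine_similarity(image_embedding, anchor_embeddings, dim=-1)  # shape (2,)

# Step 4: softmax over the two similarities; the probability assigned to the
# "good" anchor is the CLIP-IQA score in [0, 1].
clip_iqa_score = cos_sim.softmax(dim=-1)[0]
print(f"CLIP-IQA on toy embeddings: {clip_iqa_score:0.4f}")
```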
32 changes: 13 additions & 19 deletions tests/test_clip_iqa.py
@@ -78,38 +78,32 @@ def test_clip_iqa_input_dtype_does_not_change(clipiqa: _Loss, x_rgb: torch.Tenso

def test_clip_iqa_dims_work(clipiqa: _Loss, device: str) -> None:
clipiqa = clipiqa.to(device)
x_3dims = [torch.rand((3, 96, 96)), torch.rand((3, 128, 128)), torch.rand((3, 160, 160))]
for x in x_3dims:
clipiqa(x.to(device))

x_4dims = [torch.rand((3, 3, 96, 96)), torch.rand((4, 3, 128, 128)), torch.rand((5, 3, 160, 160))]
for x in x_4dims:
clipiqa(x.to(device))


def test_clip_iqa_results_equal_for_3_and_4_dims(clipiqa: _Loss, device: str) -> None:
clipiqa = clipiqa.to(device)
x = torch.rand((3, 128, 128))
x_copy = x[None]
x_result = clipiqa(x.to(device))
x_copy_result = clipiqa(x_copy.to(device))
assert torch.isclose(x_result, x_copy_result, rtol=1e-2), \
f'Expected values to be equal, got {x_result} and {x_copy_result}'


def test_clip_iqa_dims_does_not_work(clipiqa: _Loss, device: str) -> None:
clipiqa = clipiqa.to(device)
x_2dims = [torch.rand((96, 96)), torch.rand((128, 128)), torch.rand((160, 160))]
with pytest.raises(AssertionError):
for x in x_2dims:
for x in x_2dims:
with pytest.raises(AssertionError):
clipiqa(x.to(device))

x_1dims = [torch.rand((96)), torch.rand((128)), torch.rand((160))]
with pytest.raises(AssertionError):
for x in x_1dims:

for x in x_1dims:
with pytest.raises(AssertionError):
clipiqa(x.to(device))

x_3dims = [torch.rand((3, 96, 96)), torch.rand((3, 128, 128)), torch.rand((3, 160, 160))]
for x in x_3dims:
with pytest.raises(AssertionError):
clipiqa(x.to(device))

x_5dims = [torch.rand((1, 3, 3, 96, 96)), torch.rand((2, 4, 3, 128, 128)), torch.rand((1, 5, 3, 160, 160))]
with pytest.raises(AssertionError):
for x in x_5dims:

for x in x_5dims:
with pytest.raises(AssertionError):
clipiqa(x.to(device))
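
One note on the test refactor above: when `pytest.raises` wraps the whole loop, the first input that raises satisfies the context manager and exits the loop, so the remaining inputs are never exercised; moving the context manager inside the loop checks every case. The same checks could also be written with parametrization, e.g. (illustrative shapes, reusing the existing `clipiqa` and `device` fixtures):

```python
import pytest
import torch


@pytest.mark.parametrize("shape", [(96, 96), (128,), (3, 96, 96), (1, 3, 3, 96, 96)])
def test_clip_iqa_rejects_non_4d_input(clipiqa, device, shape) -> None:
    # Only 4D (N, C, H, W) inputs are accepted after this change.
    with pytest.raises(AssertionError):
        clipiqa(torch.rand(shape).to(device))
```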
