Custom SSIM loss

Hi,

I am trying to build a custom loss function for a neural network whose output is an image.
While researching this, I came across the SSIM loss. Here is the implementation I found (from kornia):

from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from kornia.filters import get_gaussian_kernel2d

class SSIM(nn.Module):
    r"""Creates a criterion that measures the Structural Similarity (SSIM)
    index between each element in the input `x` and target `y`, and returns
    the corresponding Structural dissimilarity (DSSIM) loss.

    The index can be described as:

    .. math::

      \text{SSIM}(x, y) = \frac{(2\mu_x\mu_y+c_1)(2\sigma_{xy}+c_2)}
        {(\mu_x^2+\mu_y^2+c_1)(\sigma_x^2+\sigma_y^2+c_2)}

    where:
      - :math:`\mu_x`, :math:`\mu_y` are the local Gaussian-weighted means,
        :math:`\sigma_x^2`, :math:`\sigma_y^2` the local variances and
        :math:`\sigma_{xy}` the local covariance of :math:`x` and :math:`y`.
      - :math:`c_1=(k_1 L)^2` and :math:`c_2=(k_2 L)^2`, with
        :math:`k_1=0.01` and :math:`k_2=0.03`, are two variables to
        stabilize the division with weak denominator.
      - :math:`L` is the dynamic range of the pixel-values (``max_val``;
        typically :math:`2^{\#\text{bits per pixel}}-1`, or 1 for images
        scaled to :math:`[0, 1]`).

    The loss, or the Structural dissimilarity (DSSIM), is finally described
    as:

    .. math::

      \text{loss}(x, y) = \frac{1 - \text{SSIM}(x, y)}{2}

    where :math:`1 - \text{SSIM}` is clamped to :math:`[0, 1]` before
    halving, so every element of the loss lies in :math:`[0, 0.5]`.

    Local statistics are computed with a ``window_size x window_size``
    Gaussian window (sigma 1.5) and zero padding, so the loss map keeps the
    input's spatial size; border statistics include the zero padding.

    Arguments:
        window_size (int): the size of the Gaussian kernel. Must be an odd
          positive integer.
        reduction (str, optional): Specifies the reduction to apply to the
          output: 'none' | 'mean' | 'sum'. 'none': no reduction will be
          applied, 'mean': the sum of the output will be divided by the
          number of elements in the output, 'sum': the output will be
          summed. Default: 'none'. Use 'mean' or 'sum' when calling
          ``loss.backward()``, which requires a scalar.
        max_val (float): the dynamic range of the images. Default: 1.

    Returns:
        Tensor: the DSSIM loss.

    Raises:
        ValueError: if ``window_size`` is not an odd positive integer, if the
          inputs have mismatching shape/device/dtype or are not 4-D, or if
          ``reduction`` is not one of 'none' | 'mean' | 'sum'.
        TypeError: if an input is not a ``torch.Tensor``.

    Shape:
        - Input: :math:`(B, C, H, W)`
        - Target: :math:`(B, C, H, W)`
        - Output: scalar, or :math:`(B, C, H, W)` if reduction is 'none'.

    Examples::

        >>> input1 = torch.rand(1, 4, 5, 5)
        >>> input2 = torch.rand(1, 4, 5, 5)
        >>> ssim = SSIM(5, reduction='none')
        >>> loss = ssim(input1, input2)  # 1x4x5x5
    """

    def __init__(
            self,
            window_size: int,
            reduction: str = 'none',
            max_val: float = 1.0) -> None:
        super().__init__()
        # An even (or non-positive) window cannot be centred; with
        # (k - 1) // 2 padding it would also shrink the output by one pixel.
        if not isinstance(window_size, int) or window_size <= 0 \
                or window_size % 2 == 0:
            raise ValueError(
                "window_size must be an odd positive integer. Got: {}"
                .format(window_size))
        self.window_size: int = window_size
        self.max_val: float = max_val
        self.reduction: str = reduction

        # sigma = 1.5 is the value used by the reference SSIM formulation.
        self.window: torch.Tensor = self._gaussian_kernel2d(window_size, 1.5)
        self.padding: int = self.compute_zero_padding(window_size)

        self.C1: float = (0.01 * self.max_val) ** 2
        self.C2: float = (0.03 * self.max_val) ** 2

    @staticmethod
    def _gaussian_kernel2d(window_size: int, sigma: float) -> torch.Tensor:
        """Return a normalized ``(window_size, window_size)`` Gaussian kernel.

        Built as the outer product of a normalized 1-D Gaussian centred on
        ``window_size // 2``, so the 2-D kernel also sums to one.
        """
        coords = torch.arange(window_size, dtype=torch.float32) \
            - window_size // 2
        gauss = torch.exp(-coords.pow(2) / (2.0 * sigma ** 2))
        gauss = gauss / gauss.sum()
        return gauss.unsqueeze(1) * gauss.unsqueeze(0)

    @staticmethod
    def compute_zero_padding(kernel_size: int) -> int:
        """Computes zero padding that preserves the spatial size."""
        return (kernel_size - 1) // 2

    def filter2D(
            self,
            input: torch.Tensor,
            kernel: torch.Tensor,
            channel: int) -> torch.Tensor:
        """Convolve each of the ``channel`` channels independently.

        ``groups=channel`` applies one ``(1, k, k)`` slice of ``kernel`` per
        input channel, so ``kernel`` must be shaped ``(channel, 1, k, k)``.
        """
        return F.conv2d(input, kernel, padding=self.padding, groups=channel)

    def forward(  # type: ignore
            self,
            img1: torch.Tensor,
            img2: torch.Tensor) -> torch.Tensor:
        """Compute the (optionally reduced) DSSIM loss between two images."""
        if not torch.is_tensor(img1):
            raise TypeError("Input img1 type is not a torch.Tensor. Got {}"
                            .format(type(img1)))
        if not torch.is_tensor(img2):
            raise TypeError("Input img2 type is not a torch.Tensor. Got {}"
                            .format(type(img2)))
        if not len(img1.shape) == 4:
            raise ValueError("Invalid img1 shape, we expect BxCxHxW. Got: {}"
                             .format(img1.shape))
        if not len(img2.shape) == 4:
            raise ValueError("Invalid img2 shape, we expect BxCxHxW. Got: {}"
                             .format(img2.shape))
        if not img1.shape == img2.shape:
            raise ValueError(
                "img1 and img2 shapes must be the same. Got: {} and {}"
                .format(img1.shape, img2.shape))
        if not img1.device == img2.device:
            raise ValueError(
                "img1 and img2 must be in the same device. Got: {} and {}"
                .format(img1.device, img2.device))
        if not img1.dtype == img2.dtype:
            raise ValueError(
                "img1 and img2 must be in the same dtype. Got: {} and {}"
                .format(img1.dtype, img2.dtype))

        # prepare a depthwise kernel of shape (C, 1, k, k) on the inputs'
        # device/dtype; the window is created on CPU in __init__.
        c: int = img1.shape[1]
        tmp_kernel: torch.Tensor = self.window.to(
            device=img1.device, dtype=img1.dtype)
        kernel: torch.Tensor = tmp_kernel.reshape(
            1, 1, self.window_size, self.window_size).repeat(c, 1, 1, 1)

        # compute local mean per channel
        mu1: torch.Tensor = self.filter2D(img1, kernel, c)
        mu2: torch.Tensor = self.filter2D(img2, kernel, c)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        # compute local (co)variance per channel as E[xy] - E[x]E[y]
        sigma1_sq = self.filter2D(img1 * img1, kernel, c) - mu1_sq
        sigma2_sq = self.filter2D(img2 * img2, kernel, c) - mu2_sq
        sigma12 = self.filter2D(img1 * img2, kernel, c) - mu1_mu2

        ssim_map = ((2 * mu1_mu2 + self.C1) * (2 * sigma12 + self.C2)) / \
            ((mu1_sq + mu2_sq + self.C1) * (sigma1_sq + sigma2_sq + self.C2))

        # DSSIM; the clamp guards against SSIM values slightly outside [0, 1]
        # caused by floating-point error or anti-correlated patches.
        loss = torch.clamp(1.0 - ssim_map, min=0, max=1) / 2.

        if self.reduction == 'mean':
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        elif self.reduction != 'none':
            raise ValueError(
                "Invalid reduction mode: {}. Expected 'none', 'mean' or 'sum'."
                .format(self.reduction))
        return loss


# functional interface


def ssim(
        img1: torch.Tensor,
        img2: torch.Tensor,
        window_size: int,
        reduction: str = 'none',
        max_val: float = 1.0) -> torch.Tensor:
    r"""Function that measures the Structural Similarity (SSIM) index between
    each element in the input x and target y, returned as the DSSIM loss.

    This builds a new :class:`SSIM` (and therefore a new Gaussian window) on
    every call; when computing the loss repeatedly, e.g. in a training loop,
    prefer creating one :class:`SSIM` instance and reusing it.

    Arguments:
        img1 (torch.Tensor): first image batch of shape :math:`(B, C, H, W)`.
        img2 (torch.Tensor): second image batch, same shape as ``img1``.
        window_size (int): the size of the Gaussian kernel.
        reduction (str, optional): 'none' | 'mean' | 'sum'. Default: 'none'.
        max_val (float): the dynamic range of the images. Default: 1.

    Returns:
        torch.Tensor: the DSSIM loss (see :class:`SSIM` for details).
    """
    return SSIM(window_size, reduction, max_val)(img1, img2)

My question is: while training my network, how should I define the loss function? Should I use the function above,
`loss_function = ssim`, or the class defined previously, `loss_function = SSIM`?

Also, do I need to define a `backward` method myself, or does PyTorch handle it automatically when I call `loss.backward()`?

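You should create an SSIM object once and then call it on your output and target. Note that `window_size` is a required argument. Also pass a reducing mode such as `'mean'`. `loss.backward()` needs a scalar, but the default `reduction='none'` returns a per-pixel loss map:

criterion = SSIM(window_size=11, reduction='mean')
loss = criterion(output, target)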

If you are using only PyTorch methods in the forward method, you don’t need to define the backward manually and Autograd will create it for you.

PS: you can post code snippets by wrapping them into three backticks ```, which makes debugging easier. :wink: