# Custom loss SSIM

Hi,

I am trying to build a custom loss function for a neural network whose output is an image.
I looked into it and found the SSIM loss; below is the kornia implementation I am working from.

```python
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from kornia.filters import get_gaussian_kernel2d


class SSIM(nn.Module):
r"""Creates a criterion that measures the Structural Similarity (SSIM)
index between each element in the input x and target y.

The index can be described as:

.. math::

\text{SSIM}(x, y) = \frac{(2\mu_x\mu_y+c_1)(2\sigma_{xy}+c_2)}
{(\mu_x^2+\mu_y^2+c_1)(\sigma_x^2+\sigma_y^2+c_2)}

where:
- :math:c_1=(k_1 L)^2 and :math:c_2=(k_2 L)^2 are two variables to
stabilize the division with weak denominator.
- :math:L is the dynamic range of the pixel-values (typically this is
:math:2^{\#\text{bits per pixel}}-1).

the loss, or the Structural dissimilarity (DSSIM) can be finally described
as:

.. math::

\text{loss}(x, y) = \frac{1 - \text{SSIM}(x, y)}{2}

Arguments:
window_size (int): the size of the kernel.
max_val (float): the dynamic range of the images. Default: 1.
reduction (str, optional): Specifies the reduction to apply to the
output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
'mean': the sum of the output will be divided by the number of elements
in the output, 'sum': the output will be summed. Default: 'none'.

Returns:
Tensor: the ssim index.

Shape:
- Input: :math:(B, C, H, W)
- Target :math:(B, C, H, W)
- Output: scale, if reduction is 'none', then :math:(B, C, H, W)

Examples::

>>> input1 = torch.rand(1, 4, 5, 5)
>>> input2 = torch.rand(1, 4, 5, 5)
>>> ssim = kornia.losses.SSIM(5, reduction='none')
>>> loss = ssim(input1, input2)  # 1x4x5x5
"""

    def __init__(
            self,
            window_size: int,
            reduction: str = 'none',
            max_val: float = 1.0) -> None:
        super(SSIM, self).__init__()
        self.window_size: int = window_size
        self.max_val: float = max_val
        self.reduction: str = reduction

        self.window: torch.Tensor = get_gaussian_kernel2d(
            (window_size, window_size), (1.5, 1.5))
        # padding that keeps the output the same size as the input
        self.padding: int = self.compute_zero_padding(window_size)

        self.C1: float = (0.01 * self.max_val) ** 2
        self.C2: float = (0.03 * self.max_val) ** 2

    @staticmethod
    def compute_zero_padding(kernel_size: int) -> int:
        return (kernel_size - 1) // 2

    def filter2D(
            self,
            input: torch.Tensor,
            kernel: torch.Tensor,
            channel: int) -> torch.Tensor:
        # depthwise convolution: one Gaussian kernel per input channel
        return F.conv2d(input, kernel, padding=self.padding, groups=channel)

    def forward(  # type: ignore
            self,
            img1: torch.Tensor,
            img2: torch.Tensor) -> torch.Tensor:
        if not torch.is_tensor(img1):
            raise TypeError("Input img1 type is not a torch.Tensor. Got {}"
                            .format(type(img1)))
        if not torch.is_tensor(img2):
            raise TypeError("Input img2 type is not a torch.Tensor. Got {}"
                            .format(type(img2)))
        if not len(img1.shape) == 4:
            raise ValueError("Invalid img1 shape, we expect BxCxHxW. Got: {}"
                             .format(img1.shape))
        if not len(img2.shape) == 4:
            raise ValueError("Invalid img2 shape, we expect BxCxHxW. Got: {}"
                             .format(img2.shape))
        if not img1.shape == img2.shape:
            raise ValueError("img1 and img2 shapes must be the same. "
                             "Got: {} and {}".format(img1.shape, img2.shape))
        if not img1.device == img2.device:
            raise ValueError("img1 and img2 must be on the same device. "
                             "Got: {} and {}".format(img1.device, img2.device))
        if not img1.dtype == img2.dtype:
            raise ValueError("img1 and img2 must have the same dtype. "
                             "Got: {} and {}".format(img1.dtype, img2.dtype))
        # prepare kernel
        b, c, h, w = img1.shape
        tmp_kernel: torch.Tensor = self.window.to(img1.device).to(img1.dtype)
        kernel: torch.Tensor = tmp_kernel.repeat(c, 1, 1, 1)

        # compute local mean per channel
        mu1: torch.Tensor = self.filter2D(img1, kernel, c)
        mu2: torch.Tensor = self.filter2D(img2, kernel, c)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        # compute local sigma per channel
        sigma1_sq = self.filter2D(img1 * img1, kernel, c) - mu1_sq
        sigma2_sq = self.filter2D(img2 * img2, kernel, c) - mu2_sq
        sigma12 = self.filter2D(img1 * img2, kernel, c) - mu1_mu2

        ssim_map = ((2 * mu1_mu2 + self.C1) * (2 * sigma12 + self.C2)) / \
            ((mu1_sq + mu2_sq + self.C1) * (sigma1_sq + sigma2_sq + self.C2))

        loss = torch.clamp(torch.tensor(1.) - ssim_map, min=0, max=1) / 2.

        if self.reduction == 'mean':
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        elif self.reduction == 'none':
            pass
        return loss


######################
# functional interface
######################

def ssim(
        img1: torch.Tensor,
        img2: torch.Tensor,
        window_size: int,
        reduction: str = 'none',
        max_val: float = 1.0) -> torch.Tensor:
    r"""Function that measures the Structural Similarity (SSIM) index between
    each element in the input x and target y.

    See :class:`~kornia.losses.SSIM` for details.
    """
    return SSIM(window_size, reduction, max_val)(img1, img2)
```


My question is: while training my NN, when defining the loss function, should I do it like this:
`loss_function = ssim(...)` (the function above) or `loss_function = SSIM(...)` (the class defined previously)?

Also, should I define a `backward` method, or does PyTorch handle it automatically when calling `loss.backward()`?

You should create an object of SSIM and call it afterwards via:

```python
criterion = SSIM(window_size=5)
loss = criterion(output, target)
```


If you are using only PyTorch methods in the forward method, you don’t need to define the backward manually and Autograd will create it for you.
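
Here is a minimal sketch of a training step, just to illustrate; the model, optimizer, and tensors are placeholders, and any image-to-image model producing BxCxHxW outputs would do:

```python
import torch

model = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)   # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = SSIM(window_size=5, reduction='mean')  # 'mean' gives a scalar loss

input = torch.rand(8, 3, 32, 32)   # placeholder batch
target = torch.rand(8, 3, 32, 32)  # placeholder ground truth

optimizer.zero_grad()
output = model(input)
loss = criterion(output, target)
loss.backward()   # Autograd computes the gradients; no custom backward needed
optimizer.step()
```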

PS: you can post code snippets by wrapping them into three backticks ```, which makes debugging easier.