Analytical SSIM derivative differs from autograd computation

Hi everyone,

I am attempting to compute the analytical derivative of SSIM, which I’ve named dL_ssim_dimg1. However, I’ve noticed a discrepancy between the analytical gradient and the gradient computed by autograd. The maximum element-wise error across all iterations is approximately 1e-4 and remains fairly consistent. I’ve reviewed my computation multiple times, and I believe it’s accurate. A detailed derivation can be found here.

When I check other analytical gradients against autograd, the results are consistent: for example, the gradient of mu1.mean() is nearly identical, and an L1 loss also shows minimal error. I’m uncertain whether the SSIM discrepancy can be attributed solely to floating-point precision issues; to me, the error seems too large to be caused by this alone.

Is there a way to trace precisely how autograd computes the loss gradient with respect to img1, so I can see which intermediate values it chains together? (For context, img2 is the reference image that img1 is being optimized against.) I’ve already attempted various approaches, such as rearranging the computations, but the error remains roughly the same.

Code:

        std::pair<torch::Tensor, torch::Tensor> ssim(const torch::Tensor& img1, const torch::Tensor& img2, const torch::Tensor& window, int window_size, int channel) {

            // SSIM stability constants
            static const float C1 = 0.01f * 0.01f;
            static const float C2 = 0.03f * 0.03f;

            // local mean and variance of img1
            const auto mu1 = torch::nn::functional::conv2d(img1, window, torch::nn::functional::Conv2dFuncOptions().padding(window_size / 2).groups(channel));
            const auto mu1_sq = mu1.pow(2);
            const auto sigma1_sq = torch::nn::functional::conv2d(img1 * img1, window, torch::nn::functional::Conv2dFuncOptions().padding(window_size / 2).groups(channel)) - mu1_sq;

            // local mean and variance of img2
            const auto mu2 = torch::nn::functional::conv2d(img2, window, torch::nn::functional::Conv2dFuncOptions().padding(window_size / 2).groups(channel));
            const auto mu2_sq = mu2.pow(2);
            const auto sigma2_sq = torch::nn::functional::conv2d(img2 * img2, window, torch::nn::functional::Conv2dFuncOptions().padding(window_size / 2).groups(channel)) - mu2_sq;

            // local covariance of img1 and img2
            const auto mu1_mu2 = mu1 * mu2;
            const auto sigma12 = torch::nn::functional::conv2d(img1 * img2, window, torch::nn::functional::Conv2dFuncOptions().padding(window_size / 2).groups(channel)) - mu1_mu2;

            // luminance and contrast/structure terms of the SSIM map
            const auto l_p = (2.f * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1);
            const auto cs_p = (2.f * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2);
            const auto ssim_map = l_p * cs_p;

            // hand-derived gradient of ssim_map.mean() with respect to img1
            // (this is the quantity that differs from autograd's img1.grad)
            auto grad_mu1_wrt_img1 = torch::nn::functional::conv2d(torch::ones_like(mu1), torch::flip(window, {2, 3}), torch::nn::functional::Conv2dFuncOptions().padding(window_size / 2).groups(channel));
            auto lp_x = 2.f * grad_mu1_wrt_img1 * ((mu2 - mu1 * l_p) / (mu1_sq + mu2_sq + C1));
            auto cs_x = (2.f / (sigma1_sq + sigma2_sq + C2)) * grad_mu1_wrt_img1 * ((img2 - mu2) - cs_p * (img1 - mu1));

            auto dL_ssim_dimg1 = (lp_x * cs_p + l_p * cs_x) / static_cast<float>(img1.size(1) * img1.size(2));

            // autograd path, for comparison against dL_ssim_dimg1
            auto loss = ssim_map.mean();
            loss.backward();

            return {loss, dL_ssim_dimg1};
        }

I have compiled PyTorch in debug mode with the intention of stepping through the computation and reproducing it exactly; I would expect the gradients to be identical when computed this way. However, that is a huge task if you are not familiar with the internals.

Hi Janusch!

I wouldn’t rule out the possibility of your discrepancy being caused by floating-point
round-off error (but I’m not saying that it is). In multi-step computations, round-off
error can accumulate or even be amplified.

One easy (and sensible) test for this is to run the computation in both single
(float()) and double (double()) precision. Even if your discrepancy seems
significantly larger than round-off error, if it drops several orders of magnitude
when you switch from single to double precision, round-off error is likely the cause.
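
For example, here is a minimal sketch of such a check, with a toy function (and its
hand-derived gradient) standing in for your ssim():

import torch

def max_grad_err (dtype):
    # toy stand-in for ssim(): f = a / b with a = x + 1 and b = x*x + 1,
    # whose gradient is derived by hand via the quotient rule
    x = torch.rand (1000, dtype = dtype, requires_grad = True)
    a = x + 1.0
    b = x * x + 1.0
    (a / b).sum().backward()        # autograd gradient ends up in x.grad
    xd = x.detach()
    analytic = ((xd * xd + 1.0) - (xd + 1.0) * 2.0 * xd) / (xd * xd + 1.0)**2
    return  (analytic - x.grad).abs().max().item()

print ('float32 max err:', max_grad_err (torch.float32))
print ('float64 max err:', max_grad_err (torch.float64))
# if the float64 error is several orders of magnitude smaller, the float32
# discrepancy is consistent with round-off error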

Best.

K. Frank

Hi KFrank,
your hint has already helped me. I changed the precision, but nothing happened, so I investigated further by comparing partial gradients and discovered a difference. So it is not a numerical issue.

The lp_x computation is now fixed:

                auto grad_u_wrt_mu1 = 2. * mu2;
                auto grad_v_wrt_mu1 = 2. * mu1;
                auto grad_lp_wrt_mu1 = (grad_u_wrt_mu1 * (mu1_sq + mu2_sq + C1) - (2. * mu1_mu2 + C1) * grad_v_wrt_mu1) / (mu1_sq + mu2_sq + C1).pow(2);
                auto lp_x = torch::nn::functional::conv2d(grad_lp_wrt_mu1, torch::flip(window, {2, 3}), torch::nn::functional::Conv2dFuncOptions().padding(window_size / 2).groups(channel));

This gives me results identical to autograd.
But cs_x still does not seem entirely correct, and I am currently unable to figure out why:

                auto grad_direct = img1 * torch::nn::functional::conv2d(img1 * img2, torch::flip(window, {2, 3}), torch::nn::functional::Conv2dFuncOptions().padding(window_size / 2).groups(channel));
                auto grad_mu1 = torch::nn::functional::conv2d(torch::ones_like(mu1), torch::flip(window, {2, 3}), torch::nn::functional::Conv2dFuncOptions().padding(window_size / 2).groups(channel));
                auto grad_mu1_mu2_contribution = mu2 * grad_mu1;
                auto grad_part1 = 2. * torch::nn::functional::conv2d(img1, torch::flip(window, {2, 3}), torch::nn::functional::Conv2dFuncOptions().padding(window_size / 2).groups(channel));
                auto grad_part2 = 2. * mu1 * torch::nn::functional::conv2d(torch::ones_like(mu1), torch::flip(window, {2, 3}), torch::nn::functional::Conv2dFuncOptions().padding(window_size / 2).groups(channel));

                auto grad = grad_part1 - grad_part2;
                auto u_prime = grad_direct - grad_mu1_mu2_contribution;
                auto v_prime = grad_part1 - grad_part2;
                auto cs_x = (u_prime * (sigma1_sq + sigma2_sq + C2) - (2. * sigma12 + C2) * v_prime) / torch::pow(sigma1_sq + sigma2_sq + C2, 2);

Is there a way to trace the autograd computation, to see what exactly is chained together?
Best,
Janusch

Hi Janusch!

One approach is to use explicit intermediate tensors (as you seem to be doing)
and flag them with .retain_grad(). Here’s a simple example:

>>> import torch
>>> torch.__version__
'2.0.1'
>>> p = torch.tensor (1.0, requires_grad = True)
>>> t1 = 2.0 * p
>>> t1.retain_grad()
>>> t2 = 3.0 * t1
>>> t2.retain_grad()
>>> t2.backward()
>>> p.grad
tensor(6.)
>>> t1.grad
tensor(3.)
>>> t2.grad
tensor(1.)
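
The same idea carries over to tensor-valued intermediates such as your mu1. Here is a
small sketch (a stand-in computation, not your actual ssim code):

import torch
import torch.nn.functional as F

img = torch.rand (1, 1, 8, 8, requires_grad = True)
win = torch.full ((1, 1, 3, 3), 1.0 / 9.0)    # box window, just for illustration

mu1 = F.conv2d (img, win, padding = 1)
mu1.retain_grad()                             # keep the gradient flowing into mu1
mu1_sq = mu1.pow (2)
mu1_sq.retain_grad()

loss = mu1_sq.mean()
loss.backward()

print (mu1_sq.grad)     # d loss / d mu1_sq -- here 1 / mu1_sq.numel() everywhere
print (mu1.grad)        # d loss / d mu1    -- here 2 * mu1 / mu1.numel()
print (img.grad.shape)  # gradient with respect to the leaf tensor img

You can then compare mu1.grad (and so on) against the corresponding pieces of your
hand-computed chain.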

(Note that this example uses pytorch’s python interface. I assume that it is also
possible with c++, but I don’t know how to do it that way.)

Best.

K. Frank

Hi K.Frank,

I went over all the computations again and compared them against img.grad, verifying every step.
It breaks down as soon as I compute l_px, which is the derivative of l_p obtained by applying the quotient rule. I don’t understand why it diverges at this point.
Here is a simplified example in PyTorch:

    import torch
    import torch.nn.functional as F
    from math import exp
    from torch.autograd import Variable

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    img = torch.tensor([
    [
        [0.2, 0.3, 0.3, 0.4],
        [0.5, 0.3, 0.4, 0.2],
        [0.2, 0.3, 0.4, 0.4],
        [0.5, 0.4, 0.5, 0.2]
    ]], dtype=torch.float64, requires_grad=True)

    img_ref = torch.tensor([
    [
        [0.1, 0.2, 0.23, 0.4],
        [0.3, 0.21, 0.43, 0.2],
        [0.1, 0.1, 0.44, 0.14],
        [0.3345, 0.412, 0.53, 0.132]
    ]], dtype=torch.float64, requires_grad=True)

    print(img.shape)
    channels = 1
    win_size = 3
    win = create_window(win_size, channels).to(torch.float64)
    # Forward pass: local means and the luminance term l_p
    mu1 = F.conv2d(img, win, padding=win_size // 2, groups=channels)
    mu1.retain_grad()
    mu1_sq = mu1.pow(2)
    mu1_sq.retain_grad()
    mu2 = F.conv2d(img_ref, win, padding=win_size // 2, groups=channels)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    a = (2. * mu1_mu2 + C1)
    b = (mu1_sq + mu2_sq + C1)
    l_p = a / b 

    # Define a hypothetical scalar loss as the mean of l_p
    hypothetical_loss = l_p.mean()

    # Backpropagate this loss to compute the gradient with respect to img
    hypothetical_loss.backward()

    # Manually recompute the gradient and compare it against autograd's img.grad
    with torch.no_grad():
        # Manually compute the gradient of mu1 with respect to img using convolution with flipped kernel
        flipped_win = torch.flip(win, [2, 3])
        d_mu1 = F.conv2d(torch.ones_like(mu1), flipped_win, padding=win_size // 2, groups=channels)

        d_mu1_sq = F.conv2d(2 * mu1, flipped_win, padding=win_size // 2, groups=channels) 

        d_mu1_mu2 = F.conv2d(mu2, flipped_win, padding=win_size // 2, groups=channels)
        # Compute the difference between the manually computed gradient and the gradient obtained from autograd for img
        d_a = 2 * d_mu1_mu2
        d_b = d_mu1_sq
        l_px = (d_a * b - a * d_b) / (b ** 2)
        max_diff_mu1_wrt_img_check = torch.abs((l_px / img.numel()) - img.grad).max().item()
        print(l_px / img.numel())
        print(img.grad)
        print(max_diff_mu1_wrt_img_check)
 
def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()

def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window

I have provided the create_window function for completeness.

Best,
Janusch

Hi Janusch!

I’m not sure I follow what you are trying to do.

I speculate that you are trying to compute the final img.grad by performing
the chain rule “by hand” as you perform your forward-pass computation.

This can be done (see, for example, pytorch’s forward-mode autodifferentiation),
but you have to move forward from step to step with a “jacobian-vector” product,
where the “vector” is not a gradient with respect to the input variables.
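
For example (just a toy computation, not your ssim), torch.func.jvp pushes such a
“vector” (a tangent) forward through the chain for you, forming a jacobian-vector
product at each step:

import torch

def f (x):
    y = 2.0 * x             # step 1
    return  (y**2).sum()    # step 2 -- scalar output

x = torch.tensor ([1.0, 2.0, 3.0])
v = torch.tensor ([1.0, 0.0, 0.0])    # the "vector" pushed forward -- here a basis direction

out, jvp_out = torch.func.jvp (f, (x,), (v,))
print (out)        # f (x) = 4 * (1 + 4 + 9) = 56
print (jvp_out)    # directional derivative along v = d f / d x[0] = 8 * x[0] = 8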

This may be where your misconception starts. d_mu1 is not, technically speaking,
the gradient of mu1 with respect to img. To be precise, it is the gradient of
mu1.sum() with respect to img. The term gradient refers to the vector of partial
derivatives of a scalar function (with respect to a vector argument). So you do
have a gradient here, but it is the gradient of the scalar mu1.sum().

For the forward-pass scheme I think you have in mind, you would need the
jacobian – the matrix of partial derivatives of a vector-valued function with respect
to its vector argument.

Note that in your example both img and mu1 have shape [1, 4, 4], so the
jacobian of mu1 with respect to img would have shape [1, 4, 4, 1, 4, 4] (or
shape [16, 16] if you want to express it as an ordinary matrix). d_mu1 has shape
[1, 4, 4], which is right for the gradient, but not for the full jacobian.
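
To make this concrete for your conv step, here is a small self-contained check (I add a
batch dimension and use a random normalized window, so the shapes differ slightly from
yours):

import torch
import torch.nn.functional as F

img = torch.rand (1, 1, 4, 4, dtype = torch.float64, requires_grad = True)
win = torch.rand (1, 1, 3, 3, dtype = torch.float64)
win = win / win.sum()

mu1 = F.conv2d (img, win, padding = 1)
mu1.sum().backward()                     # autograd gradient of mu1.sum() in img.grad

# your "d_mu1": convolution of ones with the flipped window
d_mu1 = F.conv2d (torch.ones_like (mu1), torch.flip (win, [2, 3]), padding = 1)
print (torch.allclose (d_mu1, img.grad))   # True -- d_mu1 is the gradient of mu1.sum()

# the full jacobian has one entry per (output pixel, input pixel) pair
jac = torch.func.jacrev (lambda x: F.conv2d (x, win, padding = 1)) (img.detach())
print (jac.shape)                          # torch.Size([1, 1, 4, 4, 1, 1, 4, 4])
print (torch.allclose (jac.sum (dim = (0, 1, 2, 3)), img.grad))   # gradient = sum of jacobian over output dims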

At this point d_a and d_b are the gradients with respect to img of a.sum() and
b.sum(), respectively (as you imply that you’ve checked).

My speculation is that you are applying the quotient rule to l_p = a / b to compute
its gradient (or, more precisely, the gradient of l_p.sum()) with respect to img
(and then testing that result against img.grad). However, you can’t use the
gradients d_a and d_b to compute the gradient of l_p.sum() – they just don’t
contain enough information.

To do what you want, you need the jacobians of a and b with respect to img.
One way to understand that the gradients might lack sufficient information to
compute the subsequent gradient is to realize that you can have two vector
functions that have two different jacobians (when evaluated at some specific
value of the functions’ vector input arguments), but for which the gradients of
their scalar sums are equal. That is to say, the gradients can’t and don’t encode
all of the information in the jacobians.

These points are illustrated in a simpler setting (using the product rule) in the
following script:

import torch
print (torch.__version__)

# functions designed to have the same gradients but different jacobians for a specific x

def fa (x):
    return x**2

def fb (x):
    return  torch.arange (-1.5, 3.) * x * x.mean()

def fa_sum (x):   # scalar sum of fa
    return  fa (x).sum()

def fb_sum (x):   # scalar sum of fb
    return  fb (x).sum()

gfa = torch.func.grad (fa_sum)   # gradient of the scalar sum
gfb = torch.func.grad (fb_sum)   # gradient of the scalar sum
jfa = torch.func.jacrev (fa)     # jacobian of the vector-valued function
jfb = torch.func.jacrev (fb)     # jacobian of the vector-valued function

x = torch.arange (5.)            # value of x for which the gradients will be the same
print ('x:', x)

atol = 1.e-6                     # atol for allclose()

print ('do gradients match? :', torch.allclose (gfa (x), gfb (x), atol = atol))   # gradients are the same

print ('do jacobians match? :', torch.allclose (jfa (x), jfb (x), atol = atol))   # jacobians are different

print ('do fa gradient and jacobian-sum match? :', torch.allclose (gfa (x), jfa (x).sum (dim = 0), atol = atol))   # gradient is the column-sum of the jacobian
print ('do fb gradient and jacobian-sum match? :', torch.allclose (gfb (x), jfb (x).sum (dim = 0), atol = atol))   # gradient is the column-sum of the jacobian

# use .backward() to get gradient of (fa (x) * fb (x)).sum()
x.requires_grad = True
(fa (x) * fb (x)).sum().backward()
grad = x.grad

# can we use the gradients of fa_sum and fb_sum to get the gradient of (fa (x) * fb (x)).sum()?
d_fa_fb = gfa (x) * fb (x) + fa (x) * gfb (x)   # ??? -- does "product rule" work with just gradients?
print ('does d_fa_fb = grad? :', torch.allclose (d_fa_fb, grad, atol = atol))   # no -- the gradients do not contain enough information

# however, the jacobians do contain enough information
grad_fa_fb_sum = fb (x) @ jfa (x) + fa (x) @ jfb (x)
print ('does grad_fa_fb_sum = grad? :', torch.allclose (grad_fa_fb_sum, grad, atol = atol))   # product rule using jacobians works

And here is the script’s output:

2.0.1
x: tensor([0., 1., 2., 3., 4.])
do gradients match? : True
do jacobians match? : False
do fa gradient and jacobian-sum match? : True
do fb gradient and jacobian-sum match? : True
does d_fa_fb = grad? : False
does grad_fa_fb_sum = grad? : True

Best.

K. Frank