NoneType for backward gradient function

I have tried to implement a “complicated” (for me…) loss function (a porting from MATLAB repo).
Strangely it compiles, but when I try to call the .backward() function it returns to me a NoneType.
I’m analyzing the code but, even if there are many things that I’m not understanding, I don’t find the real cause of the problem.
Could someone help me in this desperate effort?

The code to reproduce the error is:

import torch
import math
import numpy as np
from torch import nn


def normalize_block(im):

    m = torch.mean(im)
    s = torch.std(im)

    if s == 0:
        s = 1e-7

    y = ((im-m)/s)+1

    return y, m, s

def onion_mult(onion1, onion2):
    bs, N = onion1.size()

    if N>1:
        L = int(N/2)
        a = onion1[:, :L]
        b = onion1[:, L:]
        b = torch.cat((torch.unsqueeze(b[:,0], 1), -b[:,1:]), dim=1)
        c = onion2[:, :L]
        d = onion2[:, L:]
        d = torch.cat((torch.unsqueeze(d[:, 0], 1), -d[:, 1:]), dim=1)

        if N == 2:
            ris = torch.cat(((a*c)-(d*b), (a*d)+(c*b)), dim=1)
        else:
            ris1 = onion_mult(a, c)
            ris2 = onion_mult(d, torch.cat((torch.unsqueeze(b[:, 0], 1), -b[:, 1:]), dim=1))
            ris3 = onion_mult(torch.cat((torch.unsqueeze(a[:, 0], 1), -a[:, 1:]), dim=1), d)
            ris4 = onion_mult(c, b)

            aux1 = ris1 - ris2
            aux2 = ris3 + ris4
            ris = torch.cat((aux1, aux2), dim=1)
    else:
        ris = onion1 * onion2

    return ris


def onion_mult2D(onion1, onion2):

    bs, dim3, _, _ = onion1.size()
    if (dim3 > 1):
        L = int(dim3/2)

        a = onion1[:, 0:L, :, :]
        b = onion1[:, L:, :, :]
        b = torch.cat((torch.unsqueeze(b[:,0,:,:], 1), -b[:,1:, :, :],), dim=1)
        c = onion2[:, 0:L, :, :]
        d = onion2[:, L:, :, :]
        d = torch.cat((torch.unsqueeze(d[:, 0, :, :], 1), -d[:, 1:, :, :],), dim=1)

        if dim3 == 2:
            ris = torch.cat(((a*c)-(d*b), (a*d)+(c*b)),dim=1)
        else:
            ris1 = onion_mult2D(a,c)
            ris2 = onion_mult2D(d, torch.cat((torch.unsqueeze(b[:,0,:,:], 1), -b[:, 1:, :, :]),dim=1))
            ris3 = onion_mult2D(torch.cat((torch.unsqueeze(a[:,0,:,],1), -a[:, 1:, :, :]), dim=1), d)
            ris4 = onion_mult2D(c,b)

            aux1 = ris1-ris2
            aux2 = ris3+ris4

            ris = torch.cat((aux1, aux2), dim=1)
    else:
        ris = onion1 * onion2

    return ris


def onions_quality(im1, im2, size, device):

    im1 = im1.type(torch.double).to(device)
    im2 = im2.type(torch.double).to(device)
    im2 = torch.cat((torch.unsqueeze(im2[:, 0, :, :], 1), -im2[:, 1:, :,:]), dim=1)
    batch_size, dim3, _, _ = im1.size()

    for bs in range(batch_size):
        for i in range(dim3):
            a1, s, t = normalize_block(im1[bs,i,:,:])
            im1[bs,i,:,:] = a1
            if s == 0:
                if i == 0:
                    im2[bs,i,:,:] = im2[bs,i,:,:]-s+1
                else:
                    im2[bs, i, :, :] = -(-im2[bs, i, :, :] - s + 1)
            else:
                if i == 0:
                    im2[bs, i, :, :] = ((im2[bs, i, :, :] - s)/t)+1
                else:
                    im2[bs, i, :, :] = -(((-im2[bs, i, :, :] - s)/t)+1)


    m1 = torch.mean(im1, dim=(2, 3))
    m2 = torch.mean(im2, dim=(2, 3))
    mod_q1m = torch.sqrt(torch.sum(m1**2, dim=1))
    mod_q2m = torch.sqrt(torch.sum(m2**2, dim=1))

    mod_q1 = torch.sqrt(torch.sum(im1 ** 2, dim=1))
    mod_q2 = torch.sqrt(torch.sum(im2 ** 2, dim=1))

    term2 = mod_q1m * mod_q2m
    term4 = mod_q1m**2 + mod_q2m**2
    temp = [size ** 2 / (size**2 - 1)] * batch_size
    temp = torch.from_numpy(np.asarray(temp)).to(device)
    int1 = torch.clone(temp)
    int2 = torch.clone(temp)
    int3 = torch.clone(temp)
    int1 = int1 * torch.mean(mod_q1**2)
    int2 = int2 * torch.mean(mod_q2**2)
    int3 = int3 * (mod_q1m **2 + mod_q2m ** 2)
    term3 = int1 + int2 - int3

    mean_bias = 2*term2/term4
    if term3 == 0:
        q = torch.zeros((batch_size, 1, 1, dim3), device=device, requires_grad=True)
        q[:,:,:,dim3-1] = mean_bias
    else:
        cbm = 2/term3
        qu = onion_mult2D(im1, im2)
        qm = onion_mult(m1, m2)
        #qv = torch.zeros((batch_size, dim3), device=device, requires_grad=True)

        #for bs in range(batch_size):
            #for i in range(dim3):
                #qv[:, i] = (size**2)/(size**2 - 1) * torch.mean(qu[:, i, :,:], dim=(2,3))

        qv = (size ** 2) / (size ** 2 - 1) * torch.mean(qu, dim=(-2, -1))

        q = qv - temp*qm
        q = q*mean_bias*cbm

    return q

class complicated_loss (nn.Module):
    def __init__(self, device, Q_block_size=32, Q_shift=32):
        super(complicated_loss, self).__init__()

        self.Q_block_size = Q_block_size
        self.Q_shift = Q_shift
        self.device = device



    def forward(self, outputs, labels):

        bs, dim3, dim1, dim2 = labels.size()
        _, _, ddim1, ddim2 = outputs.size()

        stepx = math.ceil(dim1/self.Q_shift)
        stepy = math.ceil(dim2/self.Q_shift)

        if stepy <= 0:
            stepy = 1
            stepx = 1

        est1 = (stepx - 1)*self.Q_shift+self.Q_block_size-dim1
        est2 = (stepy - 1)*self.Q_shift+self.Q_block_size-dim2

        if (est1 != 0)+(est2 != 0) > 0:

            padding = torch.nn.ReflectionPad2d((0, est1, 0, est2))

            reference = padding(labels)
            fused = padding(outputs)

            outputs = fused.type(torch.int16).to(self.device)
            labels = reference.type(torch.int16).to(self.device)

        bs, dim3, dim1, dim2 = labels.size()

        if(math.ceil(math.log2(dim3)) - math.log2(dim3) != 0):
            exp_difference = 2 ** (torch.ceil(torch.log2(dim3))) - dim3
            diff = torch.zeros((bs, exp_difference, dim1, dim2), device=self.device, requires_grad=True).type(torch.int16)
            labels = torch.cat((labels, diff), dim=1)
            outputs = torch.cat((outputs, diff), dim=1)

        bs, dim3, dim1, dim2 = labels.size()

        values = torch.zeros((bs, dim3, stepx, stepy), device=self.device, requires_grad=True)

        for j in range(stepx):
            for i in range(stepy):
                o = onions_quality(labels[:, :, j * self.Q_shift:j * self.Q_shift + self.Q_block_size, i * self.Q_shift : i * self.Q_shift + self.Q_block_size], outputs[:, :, j * self.Q_shift:j * self.Q_shift + self.Q_block_size, i * self.Q_shift : i * self.Q_shift + self.Q_block_size], self.Q_block_size, self.device)
                values.data[:,:, j,i] = o
        index_map = torch.sqrt(torch.sum(values**2, dim=1))
        index = torch.mean(index_map)

        loss = 1.0 - index

        return loss


if __name__ == '__main__':
    device = torch.device('cpu')
    a = np.arange(256*256)
    a = a.reshape(256,256)
    a = a.astype('float32')
    a = np.expand_dims(a, (0,1))
    a = torch.from_numpy(a)
    b = torch.zeros(a.size())
    a.requires_grad = True

    criterion = complicated_loss(device)

    loss = criterion(a,b)
    f = loss.backward()


The problem is that this loss does not update the weights of the network during the training loop. Where am I wrong?
Thank you!

You are manually creating a new tensor with requires_grad=True inside your module:

values = torch.zeros((bs, dim3, stepx, stepy), device=self.device, requires_grad=True)

which is not attached to any computation graph (and thus also not to the input to the criterion).
Unrelated to this particular issue, but also don’t use the .data attribute to manipulate tensors, as it’s deprecated and can yield unwanted side effects.

Instead of recreating the values tensor, try to append the outputs of onion_quality to e.g. a list and create a tensor via torch.stack():

       values = []

        for j in range(stepx):
            for i in range(stepy):
                o = onions_quality(labels[:, :, j * self.Q_shift:j * self.Q_shift + self.Q_block_size, i * self.Q_shift : i * self.Q_shift + self.Q_block_size], outputs[:, :, j * self.Q_shift:j * self.Q_shift + self.Q_block_size, i * self.Q_shift : i * self.Q_shift + self.Q_block_size], self.Q_block_size, self.device)
                values.append(o)
        values = torch.stack(values)

I haven’t looked through your entire code and while this solves the gradient error (a.grad will show values), the gradient will contain all NaNs, so check if you are e.g. dividing by zero etc.

Thank you so much for your answer.
I’m still troubled with the code. Exploiting the debug there are not strange things, such as zero divisions and also the final value of the loss is correct and coherent.
Could some recursive operations ( onion_mult() and onion_mult2D() functions) be the real problem? There is a way to convert a stack of images in quaternions and to apply the operations more efficiently?

I’ve made slight changes in the code, trying to solve the issue.


import torch
import math
import numpy as np
from torch import nn


def normalize_block(im):

    m = torch.mean(im)
    s = torch.std(im)

    if s == 0:
        s = 1e-7

    y = ((im-m)/s)+1

    return y, m, s

def onion_mult(onion1, onion2):
    bs, N = onion1.size()

    if N>1:
        L = int(N/2)
        a = onion1[:, :L]
        b = onion1[:, L:]
        b = torch.cat((torch.unsqueeze(b[:,0], 1), -b[:,1:]), dim=1)
        c = onion2[:, :L]
        d = onion2[:, L:]
        d = torch.cat((torch.unsqueeze(d[:, 0], 1), -d[:, 1:]), dim=1)

        if N == 2:
            ris = torch.cat(((a*c)-(d*b), (a*d)+(c*b)), dim=1)
        else:
            ris1 = onion_mult(a, c)
            ris2 = onion_mult(d, torch.cat((torch.unsqueeze(b[:, 0], 1), -b[:, 1:]), dim=1))
            ris3 = onion_mult(torch.cat((torch.unsqueeze(a[:, 0], 1), -a[:, 1:]), dim=1), d)
            ris4 = onion_mult(c, b)

            aux1 = ris1 - ris2
            aux2 = ris3 + ris4
            ris = torch.cat((aux1, aux2), dim=1)
    else:
        ris = onion1 * onion2

    return ris


def onion_mult2D(onion1, onion2):

    bs, dim3, _, _ = onion1.size()
    if (dim3 > 1):
        L = int(dim3/2)

        a = onion1[:, 0:L, :, :]
        b = onion1[:, L:, :, :]
        b = torch.cat((torch.unsqueeze(b[:,0,:,:], 1), -b[:,1:, :, :],), dim=1)
        c = onion2[:, 0:L, :, :]
        d = onion2[:, L:, :, :]
        d = torch.cat((torch.unsqueeze(d[:, 0, :, :], 1), -d[:, 1:, :, :],), dim=1)

        if dim3 == 2:
            ris = torch.cat(((a*c)-(d*b), (a*d)+(c*b)),dim=1)
        else:
            ris1 = onion_mult2D(a,c)
            ris2 = onion_mult2D(d, torch.cat((torch.unsqueeze(b[:,0,:,:], 1), -b[:, 1:, :, :]),dim=1))
            ris3 = onion_mult2D(torch.cat((torch.unsqueeze(a[:,0,:,],1), -a[:, 1:, :, :]), dim=1), d)
            ris4 = onion_mult2D(c,b)

            aux1 = ris1-ris2
            aux2 = ris3+ris4

            ris = torch.cat((aux1, aux2), dim=1)
    else:
        ris = onion1 * onion2

    return ris


def onions_quality(im1, im2, size, device):

    im1 = im1.type(torch.double).to(device)
    im2 = im2.type(torch.double).to(device)
    im2 = torch.cat((torch.unsqueeze(im2[:, 0, :, :], 1), -im2[:, 1:, :,:]), dim=1)
    batch_size, dim3, _, _ = im1.size()

    for bs in range(batch_size):
        for i in range(dim3):
            a1, s, t = normalize_block(im1[bs,i,:,:])
            im1[bs,i,:,:] = a1
            if s == 0:
                if i == 0:
                    im2[bs,i,:,:] = im2[bs,i,:,:]-s+1
                else:
                    im2[bs, i, :, :] = -(-im2[bs, i, :, :] - s + 1)
            else:
                if i == 0:
                    im2[bs, i, :, :] = ((im2[bs, i, :, :] - s)/t)+1
                else:
                    im2[bs, i, :, :] = -(((-im2[bs, i, :, :] - s)/t)+1)


    m1 = torch.mean(im1, dim=(2, 3))
    m2 = torch.mean(im2, dim=(2, 3))
    mod_q1m = torch.sqrt(torch.sum(torch.pow(m1,2), dim=1))
    mod_q2m = torch.sqrt(torch.sum(torch.pow(m2,2), dim=1))

    mod_q1 = torch.sqrt(torch.sum(torch.pow(im1, 2), dim=1))
    mod_q2 = torch.sqrt(torch.sum(torch.pow(im2, 2), dim=1))

    term2 = mod_q1m * mod_q2m
    term4 = torch.pow(mod_q1m,2) + torch.pow(mod_q2m,2)
    temp = [(size * size) / (size * size - 1)] * batch_size
    temp = torch.from_numpy(np.asarray(temp)).to(device)
    int1 = torch.clone(temp)
    int2 = torch.clone(temp)
    int3 = torch.clone(temp)
    int1 = int1 * torch.mean(torch.pow(mod_q1,2))
    int2 = int2 * torch.mean(torch.pow(mod_q2,2))
    int3 = int3 * (torch.pow(mod_q1m,2) + torch.pow(mod_q2m,2))
    term3 = int1 + int2 - int3

    mean_bias = 2*term2/term4
    if term3 == 0:
        q = torch.zeros((batch_size, 1, 1, dim3), device=device, requires_grad=False)
        q[:,:,:,dim3-1] = mean_bias
    else:
        cbm = 2/term3
        qu = onion_mult2D(im1, im2)
        qm = onion_mult(m1, m2)
        #qv = torch.zeros((batch_size, dim3), device=device, requires_grad=False)

        #for bs in range(batch_size):
            #for i in range(dim3):
                #qv[:, i] = (size**2)/(size**2 - 1) * torch.mean(qu[:, i, :,:], dim=(2,3))

        qv = (size * size) / (size * size - 1) * torch.mean(qu, dim=(-2, -1))

        q = qv - temp*qm
        q = q*mean_bias*cbm

    return q

class complicated_loss (nn.Module):
    def __init__(self, device, Q_block_size=32, Q_shift=32):
        super(complicated_loss, self).__init__()

        self.Q_block_size = Q_block_size
        self.Q_shift = Q_shift
        self.device = device



    def forward(self, outputs, labels):

        bs, dim3, dim1, dim2 = labels.size()
        _, _, ddim1, ddim2 = outputs.size()

        stepx = math.ceil(dim1/self.Q_shift)
        stepy = math.ceil(dim2/self.Q_shift)

        if stepy <= 0:
            stepy = 1
            stepx = 1

        est1 = (stepx - 1)*self.Q_shift+self.Q_block_size-dim1
        est2 = (stepy - 1)*self.Q_shift+self.Q_block_size-dim2

        if (est1 != 0)+(est2 != 0) > 0:

            padding = torch.nn.ReflectionPad2d((0, est1, 0, est2))

            reference = padding(labels)
            fused = padding(outputs)

            outputs = fused.type(torch.int16).to(self.device)
            labels = reference.type(torch.int16).to(self.device)

        bs, dim3, dim1, dim2 = labels.size()

        if(math.ceil(math.log2(dim3)) - math.log2(dim3) != 0):
            exp_difference = 2 ** (math.ceil(math.log2(dim3))) - dim3
            diff_labels = torch.zeros((bs, exp_difference, dim1, dim2), device=self.device, requires_grad=False).type(torch.int16)
            diff_outputs = torch.zeros((bs, exp_difference, dim1, dim2), device=self.device, requires_grad=True).type(torch.int16)
            labels = torch.cat((labels, diff_labels), dim=1)
            outputs = torch.cat((outputs, diff_outputs), dim=1)

        bs, dim3, dim1, dim2 = labels.size()
        """
        values = torch.zeros((bs, dim3, stepx, stepy), device=self.device, requires_grad=True)

        for j in range(stepx):
            for i in range(stepy):
                o = onions_quality(labels[:, :, j * self.Q_shift:j * self.Q_shift + self.Q_block_size, i * self.Q_shift : i * self.Q_shift + self.Q_block_size], outputs[:, :, j * self.Q_shift:j * self.Q_shift + self.Q_block_size, i * self.Q_shift : i * self.Q_shift + self.Q_block_size], self.Q_block_size, self.device)
                values.data[:,:, j,i] = o
        """
        values = []

        for j in range(stepx):
            for i in range(stepy):
                o = onions_quality(labels[:, :, j * self.Q_shift:j * self.Q_shift + self.Q_block_size,
                                   i * self.Q_shift: i * self.Q_shift + self.Q_block_size],
                                   outputs[:, :, j * self.Q_shift:j * self.Q_shift + self.Q_block_size,
                                   i * self.Q_shift: i * self.Q_shift + self.Q_block_size], self.Q_block_size,
                                   self.device)
                values.append(o)
        values = torch.stack(values)

        index_map = torch.sqrt(torch.sum(torch.pow(values, 2), dim=1))
        #index_map = torch.abs(torch.sum(values, dim=1))
        index = torch.mean(index_map)

        loss = 1.0 - index

        return loss


if __name__ == '__main__':
    device = torch.device('cpu')
    torch.autograd.set_detect_anomaly(True)
    a = np.arange(256*256)
    a = a.reshape(256,256)
    a = a.astype('float32')
    a = np.expand_dims(a, (0,1))
    a = torch.from_numpy(a)
    b = torch.clone(a)
    b[0,0,5:100,5:100] = 0
    a.requires_grad = True
    c = torch.clone(a)

    criterion2 = nn.L1Loss(reduction='mean')

    criterion = complicated_loss(device)

    loss = criterion(a,b)
    loss_2 = criterion2(c,b)
    loss_2.backward()
    loss.backward()

Now I have this error message. I think that simply this function, as it has been written, is not differentiable.

/home/matteo/anaconda3/envs/pytorch_env/lib/python3.9/site-packages/torch/autograd/__init__.py:145: UserWarning: Error detected in PowBackward0. Traceback of forward call that caused the error:
  File "/snap/pycharm-professional/271/plugins/python/helpers/pydev/pydevconsole.py", line 509, in <module>
    pydevconsole.start_client(host, port)
  File "/snap/pycharm-professional/271/plugins/python/helpers/pydev/pydevconsole.py", line 437, in start_client
    process_exec_queue(interpreter)
  File "/snap/pycharm-professional/271/plugins/python/helpers/pydev/pydevconsole.py", line 284, in process_exec_queue
    interpreter.add_exec(code_fragment)
  File "/snap/pycharm-professional/271/plugins/python/helpers/pydev/_pydev_bundle/pydev_code_executor.py", line 108, in add_exec
    more, exception_occurred = self.do_add_exec(code_fragment)
  File "/snap/pycharm-professional/271/plugins/python/helpers/pydev/pydevconsole.py", line 90, in do_add_exec
    command.run()
  File "/snap/pycharm-professional/271/plugins/python/helpers/pydev/_pydev_bundle/pydev_console_types.py", line 35, in run
    self.more = self.interpreter.runsource(text, '<input>', symbol)
  File "/home/matteo/anaconda3/envs/pytorch_env/lib/python3.9/code.py", line 74, in runsource
    self.runcode(code)
  File "/home/matteo/anaconda3/envs/pytorch_env/lib/python3.9/code.py", line 90, in runcode
    exec(code, self.locals)
  File "<input>", line 1, in <module>
  File "/snap/pycharm-professional/271/plugins/python/helpers/pydev/_pydev_bundle/pydev_umd.py", line 198, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "/snap/pycharm-professional/271/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/home/matteo/.config/JetBrains/PyCharm2021.3/scratches/scratch_16.py", line 234, in <module>
    loss = criterion(a,b)
  File "/home/matteo/anaconda3/envs/pytorch_env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/matteo/.config/JetBrains/PyCharm2021.3/scratches/scratch_16.py", line 208, in forward
    index_map = torch.sqrt(torch.sum(torch.pow(values, 2), dim=1))
 (Triggered internally at  /opt/conda/conda-bld/pytorch_1614378073166/work/torch/csrc/autograd/python_anomaly_mode.cpp:104.)
  Variable._execution_engine.run_backward(
Traceback (most recent call last):
  File "/home/matteo/anaconda3/envs/pytorch_env/lib/python3.9/code.py", line 90, in runcode
    exec(code, self.locals)
  File "<input>", line 1, in <module>
  File "/snap/pycharm-professional/271/plugins/python/helpers/pydev/_pydev_bundle/pydev_umd.py", line 198, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "/snap/pycharm-professional/271/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/home/matteo/.config/JetBrains/PyCharm2021.3/scratches/scratch_16.py", line 237, in <module>
    loss.backward()
  File "/home/matteo/anaconda3/envs/pytorch_env/lib/python3.9/site-packages/torch/tensor.py", line 245, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/matteo/anaconda3/envs/pytorch_env/lib/python3.9/site-packages/torch/autograd/__init__.py", line 145, in backward
    Variable._execution_engine.run_backward(
RuntimeError: Function 'PowBackward0' returned nan values in its 0th output.

The error message indicates the NaNs are created in PowBackward, so I guess it might be this line of code?

index_map = torch.sqrt(torch.sum(torch.pow(values, 2), dim=1))

E.g. torch.sqrt of a zero input would yield an Inf gradient, which would then be pushed to a NaN:

x = torch.tensor([0.], requires_grad=True)
y = torch.sqrt(x)
print(y)
# > tensor([0.], grad_fn=<SqrtBackward0>)
y.backward()
print(x.grad)
# > tensor([inf])

so add small eps to methods which are creating these invalid gradients or make sure all input values are valid.