Gradients returning None in custom fully convolutional neural network loss function

Hi all!
Currently I’m working on a fully convolutional neural net that outputs a mask of zeroes and ones. This mask is then multiplied elementwise by a matrix of real numbers, and all the entries of the resulting matrix are summed.
This elementwise multiplication and summation is used as my loss function. However, the value returned by loss.item() is always the same, and on further investigation I’ve seen that my gradients are None. Is it a problem with my loss function? If so, how could I fix it?
Here is my current loss function:
def my_loss(mask, matrix):
    filtered = mask * matrix
    loss = torch.sum(torch.flatten(filtered))
    return loss

The loss function looks alright and should not detach the tensors from the computation graph.
Could you post a minimal and executable code snippet to reproduce the issue, please?
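
For reference, a quick standalone check (a sketch, not part of the original reply; logits here is a hypothetical stand-in for the network output) confirms that a multiply-and-sum loss keeps the computation graph intact as long as the mask itself is produced by differentiable operations:

import torch

logits = torch.randn(4, 4, requires_grad=True)   # hypothetical stand-in for the network output
matrix = torch.randn(4, 4)                        # the real-valued input matrix

mask = torch.sigmoid(logits)                      # differentiable mask, no hard rounding
loss = torch.sum(torch.flatten(mask * matrix))    # the same multiply-and-sum loss
loss.backward()

print(logits.grad is not None)  # True: gradients flow back through the loss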

Sure, the example is written below:

import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader
import numpy as np
import torch.optim as optim

def createMatrix(dim):
    # Build a (dim, dim) matrix of squared pairwise distances between random 2D points.
    x = np.random.rand(dim) * 10
    y = np.random.rand(dim) * 10
    D = list()
    m = x.shape[0]
    for xi, yi in zip(x, y):
        xp = xi - x
        yp = yi - y
        d = np.hstack((xp.reshape(m, 1), yp.reshape(m, 1)))
        norm_d = np.dot(d, d.T).diagonal()
        D.append(norm_d)
    return np.array(D, dtype='float32')

def transformToTorch(tensor):
    # Convert a numpy array to a float tensor of shape (1, 1, dim, dim).
    transform_tensor = transforms.Compose([transforms.ToTensor(), ])
    out = transform_tensor(tensor).unsqueeze(0)
    return out

def createMiniBatch(mini_batch_size, dim, randomness=False):
    # Build a list of input matrices and wrap it in a DataLoader.
    mini_batch = list()
    if randomness:
        random_dim = np.random.randint(low=1, high=dim)
        for i in range(mini_batch_size):
            matrix = createMatrix(random_dim)
            torch_matrix = transformToTorch(matrix)
            mini_batch.append(torch_matrix)
    else:
        for i in range(mini_batch_size):
            matrix = createMatrix(dim)
            torch_matrix = transformToTorch(matrix)
            mini_batch.append(torch_matrix)
    mini_batch = DataLoader(mini_batch)
    return mini_batch

def my_loss(mask, input_matrix):
    # Elementwise product of mask and input, summed over all entries.
    filtered_matrix = mask * input_matrix
    loss_value = torch.sum(torch.flatten(filtered_matrix))
    return loss_value

class model(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3),
            nn.Sigmoid(),
            nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3),
            nn.BatchNorm2d(16),
            nn.Sigmoid(),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3),
            nn.Sigmoid()
        )

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3),
            nn.BatchNorm2d(16),
            nn.Sigmoid(),
            nn.ConvTranspose2d(in_channels=16, out_channels=8, kernel_size=3),
            nn.BatchNorm2d(8),
            nn.Sigmoid(),
            nn.ConvTranspose2d(in_channels=8, out_channels=1, kernel_size=3),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        x = torch.round(x)  # binarize the output to a 0/1 mask
        return x

n_net = model()

epochs = 10
optimizer = optim.Adam(n_net.parameters(), lr=0.001)
mini_batch = createMiniBatch(mini_batch_size=10, dim=10, randomness=False)

running_loss = 0
for param in n_net.parameters():
    param.requires_grad = True
for epoch in range(epochs):
    for i, data in enumerate(mini_batch, 0):
        x = data[0].to(torch.float32)
        x.requires_grad = True
        optimizer.zero_grad()
        n_net.train()
        mask = n_net(x).float()
        loss = my_loss(mask=mask, input_matrix=x)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print('loss value:', running_loss, 'Grad:', loss.grad)
    running_loss = 0

Returns:

loss value: 2319.3846282958984 Grad: None
loss value: 2319.3846282958984 Grad: None
loss value: 2319.3846282958984 Grad: None
loss value: 2319.3846282958984 Grad: None
loss value: 2319.3846282958984 Grad: None
loss value: 2319.3846282958984 Grad: None
loss value: 2319.3846282958984 Grad: None
loss value: 2319.3846282958984 Grad: None
loss value: 2319.3846282958984 Grad: None
loss value: 2319.3846282958984 Grad: None

You are checking the .grad attribute of the loss, which is a non-leaf tensor whose gradient is not retained by default.
Use:

        loss=my_loss(mask=mask,input_matrix=x)
        loss.retain_grad()
        loss.backward()

and you would see the implicitly passed torch.ones(1) gradient as:

loss value: 73.96950149536133 Grad: tensor(1.)
loss value: 73.96950149536133 Grad: tensor(1.)
loss value: 73.96950149536133 Grad: tensor(1.)
loss value: 73.96950149536133 Grad: tensor(1.)
loss value: 73.96950149536133 Grad: tensor(1.)
loss value: 73.96950149536133 Grad: tensor(1.)
loss value: 73.96950149536133 Grad: tensor(1.)
loss value: 73.96950149536133 Grad: tensor(1.)
loss value: 73.96950149536133 Grad: tensor(1.)
loss value: 73.96950149536133 Grad: tensor(1.)
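
To make the distinction concrete, here is a minimal standalone sketch (not from the original thread) showing that .grad is kept on leaf tensors by default, while a non-leaf result such as the loss needs retain_grad(); in practice the parameter gradients are the ones worth inspecting:

import torch

w = torch.randn(3, requires_grad=True)   # leaf tensor created by the user
x = torch.ones(3)
loss = (w * x).sum()                     # non-leaf tensor produced by an operation

loss.retain_grad()                       # ask autograd to keep .grad on the non-leaf loss
loss.backward()

print(loss.grad)  # tensor(1.) -- only available because of retain_grad()
print(w.grad)     # tensor([1., 1., 1.]) -- leaf gradients are kept by default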

Thanks for the help @ptrblck!
Do you have any idea why the loss value never changes? It seems like the gradient is somehow still zero or None.

torch.round will create a zero gradient, since rounding is a step function whose derivative is zero almost everywhere, which is why all parameter gradients are also zero and the loss never changes.
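
A quick standalone demonstration (a sketch, not from the original thread) shows the zero gradient, and one possible workaround: a straight-through estimator that rounds in the forward pass but uses the identity gradient in the backward pass.

import torch

x = torch.tensor([0.2, 0.7], requires_grad=True)

# Hard rounding: the derivative is zero (almost) everywhere, so nothing reaches x.
torch.round(x).sum().backward()
print(x.grad)  # tensor([0., 0.])

# Hypothetical workaround: a straight-through estimator that keeps the rounded
# values in the forward pass but passes the gradient through unchanged.
x.grad = None
x_st = x + (torch.round(x) - x).detach()
x_st.sum().backward()
print(x.grad)  # tensor([1., 1.])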
