Model param.grad is None, how to debug?

I have code that accumulates the grad of each layer after calling .backward() on the loss.

It was working, but after some change all of the model's parameters have grad set to None.

I guess since grad is None, no training is happening.

When does it usually happen? What should I check to find out the cause?

A couple of notes about the model:
The model I am having this issue with is the discriminator of a GAN.
I generate data from the generator but call detach() before I feed it to the discriminator.
Also, does using a GPU make any difference in the computation graph? It seems to be fine on CPU.


Detaching the output of your generator is fine if you don't need gradients in the generator but only in the discriminator.

Usually you get None gradients if the computation graph was somehow detached, e.g. by calling .item(), .numpy(), rewrapping a tensor as x = torch.tensor(x, requires_grad=True), etc.
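For example, a minimal sketch showing how rewrapping breaks the graph (the variable names are just for illustration):

import torch

lin = torch.nn.Linear(2, 2)
x = torch.randn(1, 2)

out = lin(x)
print(out.grad_fn)       # <AddmmBackward0 ...>: still attached to the graph

out = torch.tensor(out)  # rewrapping creates a new leaf tensor
print(out.grad_fn)       # None: the graph is gone, backward() can't reach lin
# out.sum().backward()   # would now raise an error, as out does not require grad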

Do you get any valid gradients in your discriminator? Could you check the requires_grad attributes of the parameters?

Using the GPU should not make a difference, so a reproducible code snippet would be good to have in order to debug it.


So I did some investigation:

        # point 1
        output, _, _ = self.discriminator['model'](data)
        # point 2
        loss = self.discriminator['loss_fn'](output, target)
        # point 3
        loss.backward()
        # point 4
        self.discriminator['optimizer'].step()

At each point I printed requires_grad and grad,
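roughly with a helper like this (the helper itself is only a sketch of what I ran):

def report(tag, model, **tensors):
    # print requires_grad and grad for the given tensors ...
    for name, t in tensors.items():
        print(tag, ':', name, t.requires_grad, t.grad)
    # ... and for every parameter of the model
    for name, param in model.named_parameters():
        print(tag, ':', name, param.requires_grad, param.grad)

# e.g. report('point 1', self.discriminator['model'], data=data, target=target)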

point 1 : data False None
point 1 : target False None
point 1 :  module.dis_model.0.weight True None
point 1 :  module.dis_model.0.bias True None
point 1 :  module.dis_model.2.weight True None
point 1 :  module.dis_model.2.bias True None
point 1 :  module.discriminator.0.weight True None
point 1 :  module.discriminator.0.bias True None
point 1 :  module.encoder.0.weight True None
point 1 :  module.encoder.0.bias True None
point 1 :  module.co_encoder.0.weight True None
point 1 :  module.co_encoder.0.bias True None

point 2 : output True None

point 3 : loss True None

point 4 : output True None
point 4 : target False None
point 4 : loss True None
point 4 : output True None
point 4 :  module.dis_model.0.weight True None
point 4 :  module.dis_model.0.bias True None
point 4 :  module.dis_model.2.weight True None
point 4 :  module.dis_model.2.bias True None
point 4 :  module.discriminator.0.weight True None
point 4 :  module.discriminator.0.bias True None
point 4 :  module.encoder.0.weight True None
point 4 :  module.encoder.0.bias True None
point 4 :  module.co_encoder.0.weight True None
point 4 :  module.co_encoder.0.bias True None

The network looks like this:

  Discriminator(
  (dis_model): Sequential(
    (0): Linear(in_features=784, out_features=32, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Linear(in_features=32, out_features=32, bias=True)
    (3): LeakyReLU(negative_slope=0.2)
  )
  (discriminator): Sequential(
    (0): Linear(in_features=32, out_features=1, bias=True)
    (1): Sigmoid()
  )
  (encoder): Sequential(
    (0): Linear(in_features=32, out_features=5, bias=True)
    (1): Softmax()
  )
  (co_encoder): Sequential(
    (0): Linear(in_features=32, out_features=0, bias=True)
    (1): Softmax()
  )
)

In this training, I only care about the output from discriminator, which goes through dis_model first.

Should requires_grad of the input data be True? Is that the issue?

The model definition looks alright and shouldn’t be the problem:

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.dis_model = nn.Sequential(
            nn.Linear(784, 32),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(32, 32),
            nn.LeakyReLU(negative_slope=0.2)
        )
        self.discriminator = nn.Sequential(
            nn.Linear(32, 1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        x = self.dis_model(x)
        x = self.discriminator(x)
        return x
    
model = Discriminator()
x = torch.randn(1, 784)
output = model(x)
output.mean().backward()

for name, param in model.named_parameters():
    print(name, param.grad)
>dis_model.0.weight tensor([[-0.0013,  0.0058, -0.0014,  ..., -0.0057, -0.0007, -0.0087],
        [-0.0004,  0.0016, -0.0004,  ..., -0.0016, -0.0002, -0.0024],
        [ 0.0008, -0.0037,  0.0009,  ...,  0.0037,  0.0004,  0.0056],
        ...,
        [ 0.0002, -0.0011,  0.0003,  ...,  0.0011,  0.0001,  0.0016],
        [ 0.0048, -0.0215,  0.0052,  ...,  0.0212,  0.0025,  0.0322],
        [-0.0006,  0.0026, -0.0006,  ..., -0.0026, -0.0003, -0.0039]])
dis_model.0.bias tensor([ 5.0931e-03,  1.3960e-03, -3.2967e-03,  9.4221e-03, -2.1729e-03,
        -1.0545e-03, -2.0167e-04, -6.2622e-04,  3.3179e-02,  4.4160e-03,
        -1.2255e-03, -2.1004e-03,  6.5288e-04, -9.8875e-04, -2.2054e-03,
        ...

No. That would only be necessary if you want to update your input, which is probably not the case.
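For completeness, a minimal sketch of the case where you would set it, e.g. if you want to optimize the input itself (the shapes here are arbitrary):

import torch

model = torch.nn.Linear(784, 1)
x = torch.randn(1, 784, requires_grad=True)  # only needed to get x.grad

out = model(x)
out.mean().backward()
print(x.grad.shape)  # torch.Size([1, 784]): gradient w.r.t. the input itself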

Could you post the definition of the complete model?

class Discriminator(nn.Module):
    def __init__(self, input_dim, cat_dim, co_cat_dim):
        super(Discriminator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            # if normalize:
            #     layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2))
            return layers

        self.dis_model = nn.Sequential(
            *block(input_dim, 32, normalize=False),
            *block(32, 32),
        )

        # Output layers
        self.discriminator = nn.Sequential(
            nn.Linear(32, 1),
            nn.Sigmoid()
        )

        self.encoder = nn.Sequential(
            nn.Linear(32, cat_dim),
            nn.Softmax()
        )

        self.co_encoder = nn.Sequential(
            nn.Linear(32, co_cat_dim),
            nn.Softmax()
        )

    def forward(self, input):
        out = self.dis_model(input)
        validity = self.discriminator(out)
        label = self.encoder(out)
        colabel = self.co_encoder(out)

        return validity, label, colabel

And the loss function is just:

def mse_loss(output, target):
    return nn.MSELoss()(output, target)

Also, this model definition works (the encoder and co_encoder gradients stay None as expected, since only outputs[0] contributes to the loss):


model = Discriminator(1, 1, 1)
x = torch.randn(1, 1)
target = torch.randn(1, 1)
outputs = model(x)
loss = mse_loss(outputs[0], target)
loss.backward()

for name, param in model.named_parameters():
    if param.grad is not None:
        print(name, param.grad.sum())
    else:
        print(name, param.grad)

> dis_model.0.weight tensor(-0.0065)
dis_model.0.bias tensor(-0.0575)
dis_model.2.weight tensor(-1.0274)
dis_model.2.bias tensor(-0.1347)
discriminator.0.weight tensor(0.9048)
discriminator.0.bias tensor(0.3326)
encoder.0.weight None
encoder.0.bias None
co_encoder.0.weight None
co_encoder.0.bias None

Yes, I have just tried it as well…

So since loss also has requires_grad=True, I should see some value for loss.grad when I call backward, correct?

Given that I do not see anything for loss, I guess I can suspect it as a possible source of the error?

By default the loss gradient won't be retained, as it's usually just 1 (or you would pass the loss gradient manually via loss.backward(gradient)).
If you want to print it, you would need to call:

loss.retain_grad()
loss.backward()
print(loss.grad)
> tensor(1.)

I have actually narrowed the issue down.

It seems to happen when I try to use more than one GPU.

def prepare_device(n_gpu_use):
    """
    Set up GPU device(s) if available and return the device and device ids.
    """
    n_gpu = torch.cuda.device_count()
    if n_gpu_use > 0 and n_gpu == 0:
        print("Warning: There's no GPU available on this machine, "
              "training will be performed on CPU.")
        n_gpu_use = 0
    if n_gpu_use > n_gpu:
        print("Warning: The number of GPUs configured to use is {}, but only {} are available "
              "on this machine.".format(n_gpu_use, n_gpu))
        n_gpu_use = n_gpu
    device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
    list_ids = list(range(n_gpu_use))
    return device, list_ids

self.device, self.device_ids = prepare_device(4)

self.discriminator['model'] = self.discriminator['model'].to(self.device)
if len(self.device_ids) > 1:
      self.discriminator['model'] = torch.nn.DataParallel(self.discriminator['model'], device_ids=self.device_ids)

The input data is not parallelized (not even sure if it has to be…);
I just call .to(self.device) on those tensors.
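As far as I understand, nn.DataParallel chunks the input batch along dim 0 and scatters it to the devices by itself, so .to(self.device) should be enough. This is roughly how I wrap the model and check the gradients (the dims are made up; the original module is reachable via .module):

import torch

# assuming two GPUs and the Discriminator / mse_loss definitions from above
model = Discriminator(784, 5, 5).to('cuda:0')
model = torch.nn.DataParallel(model, device_ids=[0, 1])

data = torch.randn(8, 784, device='cuda:0')
target = torch.randn(8, 1, device='cuda:0')

output, _, _ = model(data)
loss = mse_loss(output, target)
loss.backward()

# gradients should be reduced back onto the original module's parameters
for name, param in model.module.named_parameters():
    print(name, param.grad is None)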

Sir, I have posted a question, can you please help me?
Here is the link:

Hi,
I have the same problem (param.grad = None) after backward.
I'm not wrapping tensors / using numpy / .item() in my code.
I set requires_grad=True and checked that the tensors are leaves.
Do you have any other idea what could cause it?

The computation graph might have (accidentally) been detached somewhere.
Could you post your model definition so that we could have a look, please?

class CPC15_BEASTpred2(torch.nn.Module):

    def __init__(self):
        super(CPC15_BEASTpred2, self).__init__()
        self.CPC18_getDist = CPC18_getDist2()
        self.CPC15_BEASTsimulation = CPC15_BEASTsimulation2()
        self.SIGMA = torch.nn.parameter.Parameter(torch.tensor(7, dtype=torch.float32, requires_grad=True))
        self.KAPA = torch.nn.parameter.Parameter(torch.tensor(3, dtype=torch.float32, requires_grad=True))
        self.BETA = torch.nn.parameter.Parameter(torch.tensor(2.6, dtype=torch.float32, requires_grad=True))
        self.GAMA = torch.nn.parameter.Parameter(torch.tensor(0.5, dtype=torch.float32, requires_grad=True))
        self.PSI = torch.nn.parameter.Parameter(torch.tensor(0.07, dtype=torch.float32, requires_grad=True))
        self.THETA = torch.nn.parameter.Parameter(torch.tensor(1, dtype=torch.float32, requires_grad=True))

    def forward(self, Ha, pHa, La, LotShapeA, LotNumA, Hb, pHb, Lb, LotShapeB, LotNumB, Amb, Corr, B_rate):
        DistA = self.CPC18_getDist(Ha, pHa, La, LotShapeA, LotNumA)
        DistB = self.CPC18_getDist(Hb, pHb, Lb, LotShapeB, LotNumB)
        Prediction = torch.zeros(size=(5, 1), dtype=torch.float32, requires_grad=True)
        nSims = 1
        for sim in range(0, nSims):
            simPred = self.CPC15_BEASTsimulation(DistA, DistB, Amb, Corr, self.SIGMA, self.KAPA, self.BETA, self.GAMA, self.PSI, self.THETA)
            Prediction = torch.add(Prediction, (1 / nSims) * simPred)
        return Prediction

As you can see, I use modules inside the main module. It's too long to show everything here, so I'm trying to think of the things that could have caused the computation graph to detach.
Thanks for the help!

Generally, these operations will detach the tensor from the computation graph (a short sketch follows after the list):

  • using another library such as numpy without writing a custom autograd.Function including the custom backward function
  • rewrapping tensors via x = torch.tensor(x)
  • explicitly detaching the tensor via x = x.detach()
  • using the .data attribute might not detach the tensor directly, but you can run into unwanted side effects, so don't use .data
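A quick sketch showing each of these in action (grad_fn being None means the tensor is no longer attached to the graph):

import torch

w = torch.randn(3, requires_grad=True)
y = w * 2
print(y.grad_fn)                # <MulBackward0 ...>: attached to the graph

print(torch.tensor(y).grad_fn)  # None: rewrapping detaches
print(y.detach().grad_fn)       # None: explicit detach

z = torch.from_numpy(y.detach().numpy())  # numpy round trip
print(z.grad_fn)                # None: autograd cannot track numpy ops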

Do slicing operations like t[:, :, :x, :y] fall under this category as well?

No, slicing operations should work and will not detach the tensor from the computation graph.
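A small sketch to verify it:

import torch

t = torch.randn(2, 3, 4, 4, requires_grad=True)
s = t[:, :, :2, :2]
print(s.grad_fn)     # <SliceBackward0 ...>: still part of the graph

s.sum().backward()
print(t.grad.shape)  # torch.Size([2, 3, 4, 4]): ones in the slice, zeros elsewhere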

Hi @ptrblck_de, so what's the solution when we rewrap a tensor via torch.flatten()?

For example, the model is a deep conv net and at the penultimate layer I want to flatten the 4-d volume (B, Cout, H, W) to a 2-d tensor of size (B, Cout*H*W) so that I can feed it to an nn.Linear layer of size (Cout*H*W, nClasses).

I'm getting grad None for the linear layer (fc1 below), as in the forward function I do:

In forward():

# x4 is the output of a conv layer of size (B, Cout, H, W)
x4 = torch.flatten(x4)
x5 = self.fc1(x4)  # where self.fc1 = nn.Linear(Cout*H*W, nClasses)

I have tried x4 = x4.view(B, -1) and this also gave the same error that grad is None for the fc1 layer.

The flattening of a tensor should not detach it, so I guess the root cause of the issue you are seeing comes from another part of the model.
Could you post the complete model definition so that we could have a look, please?
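To verify, here is a minimal sketch (dims are arbitrary) showing that gradients reach the linear layer after flattening. Note that torch.flatten(x4) without start_dim=1 would also flatten the batch dimension, so torch.flatten(x4, start_dim=1) (or x4.view(B, -1)) is the usual approach:

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
fc1 = nn.Linear(8 * 4 * 4, 10)

x = torch.randn(2, 3, 4, 4)          # (B, C, H, W)
x4 = conv(x)                         # (B, 8, 4, 4)
x4 = torch.flatten(x4, start_dim=1)  # (B, 8*4*4): keeps the batch dimension
x5 = fc1(x4)

x5.mean().backward()
print(fc1.weight.grad is None)       # False: gradients reach the linear layer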