Grad is None for nn.Parameter

Hi,

I am trying to run the following code, and both self.weight_mask.grad and self.bias_mask.grad are None.

I have checked many posts related to this "grad is None" problem, from un-differentiable functions to is_leaf and retain_grad(). None of them works in this case.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function

class FC_DropConnect(nn.Module):
    def __init__(self, dim, mlp_hidden_dim):
        super().__init__()
        '''
        drop connections in MLP by multiply MLP.weight to a binary mask
        '''
        self.fc = nn.Linear(dim, mlp_hidden_dim)

        self.weight_mask = nn.Parameter(torch.rand(self.fc.weight.shape))
        torch.nn.init.normal_(self.weight_mask, std=.02)

        self.bias_mask = nn.Parameter(torch.rand(self.fc.bias.shape))
        torch.nn.init.normal_(self.bias_mask, std=.02)

        self.binary = StraightThroughEstimator()

    # forward for StraightThroughEstimator()
    def forward(self, x):
        self.fc.weight.data *= self.binary(self.weight_mask) #weight_mask
        self.fc.bias.data *= self.binary(self.bias_mask) # bias_mask
        x = self.fc(x)
        return x

The Straight Through Estimator acts as a differentiable binary activation layer.

class STEFunction(Function):
    '''
    https://discuss.pytorch.org/t/binary-activation-function-with-pytorch/56674/4
    '''
    @staticmethod
    def forward(ctx, input):
        return (input > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
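        # straight-through: pass the incoming gradient, clamped to [-1, 1] by hardtanh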
        return F.hardtanh(grad_output)

class StraightThroughEstimator(nn.Module):
    '''
    https://discuss.pytorch.org/t/binary-activation-function-with-pytorch/56674/4
    '''
    def __init__(self):
        super(StraightThroughEstimator, self).__init__()

    def forward(self, x):
        x = STEFunction.apply(x)
        return x

In a minimal code snippet that tests the Straight Through Estimator, the grad of the nn.Parameter is not None:

##### 1st code: check grad StraightThroughEstimator ##########
x = nn.Parameter(torch.randn(5,3))
estimator = StraightThroughEstimator()
b = estimator(x)   # b consists of 0 and 1
y = torch.randn(5,3)
y[1:3] = 1
b.backward(y)
print('y', y)
print('x.requires_grad',x.requires_grad)   # True
print('x.grad', x.grad)                    # not None
##########################
x.requires_grad True
x.grad tensor([[ 0.3226,  1.0000, -0.0620],
        [ 1.0000,  1.0000,  1.0000],
        [ 1.0000,  1.0000,  1.0000],
        [-0.7098, -0.1727, -0.2121],
        [-1.0000, -0.7743, -0.1660]])

However, when I check the grad with the class FC_DropConnect, self.weight_mask.grad is always None no matter what I try, and fc.weight_mask.grad_fn is None as well:

##### 2nd code: check grad FC_DropConnect ##########
fc = FC_DropConnect(5, 3)
x = torch.randn(5)
y = torch.randn(3)
fc.zero_grad()
x_hat = fc(x)
print('fc.weight_mask.is_leaf', fc.weight_mask.is_leaf)  # True
fc.weight_mask.retain_grad()
x_hat.backward(y, retain_graph=True)

print('fc.weight_mask', fc.weight_mask)
print('fc.weight_mask.requires_grad', fc.weight_mask.requires_grad)    # True
print('fc.weight_mask.grad', fc.weight_mask.grad)      # None
print('fc.weight_mask.grad_fn ', fc.weight_mask.grad_fn) # None
###################################
fc.weight_mask Parameter containing:
tensor([[-0.0390,  0.0063,  0.0014, -0.0140,  0.0211],
        [-0.0297, -0.0519,  0.0291,  0.0076, -0.0237],
        [-0.0107,  0.0175, -0.0200,  0.0033,  0.0198]], requires_grad=True)
fc.weight_mask.requires_grad True
fc.weight_mask.grad None
fc.weight_mask.grad_fn  None

I don’t understand why the nn.Parameter's grad is not None in the 1st code but is None in the 2nd code.

Thank you

fc.weight_mask was never used in a differentiable way, but fc.fc was.
Here:

        self.fc.weight.data *= self.binary(self.weight_mask) #weight_mask
        self.fc.bias.data *= self.binary(self.bias_mask) # bias_mask

you are using the deprecated .data attribute to manipulate the trainable parameters of self.fc in place, which skips Autograd.
Remove the .data attribute usage and create a new activation tensor instead, which can then be used via the functional API:

import torch
import torch.nn as nn
import torch.nn.functional as F

class STEFunction(torch.autograd.Function):
    '''
    https://discuss.pytorch.org/t/binary-activation-function-with-pytorch/56674/4
    '''
    @staticmethod
    def forward(ctx, input):
        return (input > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        return F.hardtanh(grad_output)

class StraightThroughEstimator(nn.Module):
    '''
    https://discuss.pytorch.org/t/binary-activation-function-with-pytorch/56674/4
    '''
    def __init__(self):
        super(StraightThroughEstimator, self).__init__()

    def forward(self, x):
        x = STEFunction.apply(x)
        return x

class FC_DropConnect(nn.Module):
    def __init__(self, dim, mlp_hidden_dim):
        super().__init__()
        '''
        drop connections in MLP by multiply MLP.weight to a binary mask
        '''
        self.fc = nn.Linear(dim, mlp_hidden_dim)

        self.weight_mask = nn.Parameter(torch.rand(self.fc.weight.shape))
        torch.nn.init.normal_(self.weight_mask, std=.02)

        self.bias_mask = nn.Parameter(torch.rand(self.fc.bias.shape))
        torch.nn.init.normal_(self.bias_mask, std=.02)

        self.binary = StraightThroughEstimator()

    # forward for StraightThroughEstimator()
    def forward(self, x):
        self.fc.weight.data *= self.binary(self.weight_mask) #weight_mask
        self.fc.bias.data *= self.binary(self.bias_mask) # bias_mask
        x = self.fc(x)
        return x
    
model = FC_DropConnect(1, 1)
x = torch.randn(1, 1)
out = model(x)
out.backward()
print(model.fc.weight.grad)
# tensor([[-0.5319]])
print(model.weight_mask.grad)
# None


class FC_DropConnect(nn.Module):
    def __init__(self, dim, mlp_hidden_dim):
        super().__init__()
        '''
        drop connections in MLP by multiply MLP.weight to a binary mask
        '''
        self.fc = nn.Linear(dim, mlp_hidden_dim)

        self.weight_mask = nn.Parameter(torch.rand(self.fc.weight.shape))
        torch.nn.init.normal_(self.weight_mask, std=.02)

        self.bias_mask = nn.Parameter(torch.rand(self.fc.bias.shape))
        torch.nn.init.normal_(self.bias_mask, std=.02)

        self.binary = StraightThroughEstimator()

    # forward for StraightThroughEstimator()
    def forward(self, x):
        weight = self.fc.weight * self.binary(self.weight_mask) #weight_mask
        bias = self.fc.bias * self.binary(self.bias_mask) # bias_mask
        x = F.linear(x, weight, bias)
        return x
    
model = FC_DropConnect(1, 1)
x = torch.randn(1, 1)
out = model(x)
out.backward()
print(model.fc.weight.grad)
# tensor([[-0.0658]])
print(model.weight_mask.grad)
# tensor([[0.0375]])
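As a quick follow-up check, here is a minimal sketch building on the fixed module above (the SGD optimizer and learning rate are just placeholders, not part of the original post) showing that an optimizer step now updates the mask parameter as well:

# minimal sketch, assuming the fixed FC_DropConnect above; optimizer/lr are placeholders
model = FC_DropConnect(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

before = model.weight_mask.detach().clone()
out = model(torch.randn(1, 1))
out.backward()
optimizer.step()

print(model.weight_mask.grad)               # a tensor now, not None
print((model.weight_mask - before).abs())   # the mask parameter was updated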

@ptrblck

I understand now. Thank you very much. I’d been struggling with this for days.


Hello, I am having a similar issue where the parameter gradients do not compute (they show as None). For context, I am developing a simple generative model to characterize the x-y position given the position and velocity at the previous timestep plus some history. The neural network structure is arbitrary; however, it may be noted that, in addition to the standard forward method, there is a predict method that either returns the estimated mean vector and covariance matrix or returns a sample drawn from a multivariate normal distribution.

Below is a snippet of my training loop. I am not sure why the gradients won’t compute. I try printing the parameter gradients after the loss.backward() call and they display as None; however, when I print parameter.requires_grad it displays True. Thanks for all the help, and please let me know if there is any further info I can provide.

Hi @mrichards ,

You should also post a minimal code snippet of the network so that others (at least me) can help.

A code snippet of the neural network is pasted below. Thanks for all assistance!

Unfortunately, you have posted a screenshot instead of wrapping the code into three backticks ```, which would make debugging easier.
Also, your code does not look executable, as e.g. the input shapes are missing, so could you add these and post the code directly?

Hi @ptrblck ,

I am trying to understand this grad computation with another problem. This time the grad is all zeros.

In the code below, I want to compute a convex combination of tensors of shape [n_p, len, dim] with a 1-dimensional alpha tensor of shape [n_p] via broadcasting.

After backward(), only the [n_p, len, dim] tensors have a gradient; the grad of the [n_p] alpha tensor is all zeros.

I found this PyTorch forum discussion, which is similar to the problem I have here, but I still don’t understand why the grad is all zeros.

This code is similar to this MoCo GitHub repo, which also uses torch.einsum to compute the logits for the loss’s input.

import torch
import torch.nn as nn
from torch.nn import functional as F

def forward_mprompts(emb, emb_deep, alphas, alphas_deep):
    '''
    convex combination
    :param emb: [n_p, len, dim]
    :param emb_deep: [n_p, n_layers, len, dim]
    :param alphas: [n_p]
    :param alphas_deep: [n_p]
    :return: out_emb [1, len, dim], out_emb_deep [n_layers, len, dim]
    '''

    out_emb = torch.einsum('ijk,i->ijk', emb, F.softmax(alphas)).sum(0).unsqueeze(0)

    out_emb_deep = torch.einsum('ijkf,i->ijkf', emb_deep, F.softmax(alphas_deep)).sum(0)

    return out_emb, out_emb_deep

def MSE_loss(x, y):
    return F.mse_loss(x,y)

n_p= 5
n_ctx = 3
dim = 4
n_layers = 2

emb = nn.Parameter(torch.zeros(n_p, n_ctx, dim))

alphas = nn.Parameter(torch.rand(n_p))

emb_deep = nn.Parameter(torch.zeros(n_p, n_layers - 1, n_ctx, dim))

alphas_deep = nn.Parameter(torch.rand(n_p))

out_emb, out_emb_deep = forward_mprompts(emb, emb_deep, alphas, alphas_deep)

y_emb = torch.rand(1, n_ctx, dim)
y_emb_deep = torch.rand(n_layers - 1, n_ctx, dim)

loss = MSE_loss(out_emb, y_emb)
loss_deep = MSE_loss(out_emb_deep, y_emb_deep)

print(alphas.is_leaf)   # True
alphas.retain_grad()

print(alphas_deep.is_leaf)  # True
alphas_deep.retain_grad()

loss.backward()
loss_deep.backward()

print(emb.grad)      # Tensor torch.Size([5, 3, 4])
print(emb_deep.grad)  # Tensor torch.Size([5, 2, 3, 4])
print(alphas.grad)    # tensor([0., 0., 0., 0., 0.])
print(alphas_deep.grad)   # tensor([0., 0., 0., 0., 0.])

In addition, the code above also gives the following warning:

UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.
  out_emb = torch.einsum('ijk,i->ijk', [emb, F.softmax(alphas)]).sum(0).unsqueeze(0)

Thank you

I might misunderstand your code, so please correct me, but it seems you are multiplying emb with F.softmax(alphas) while emb is initialized with zeros?
If so, a zero gradient for alphas would be expected, since the gradient flowing back to alphas is scaled by the values of emb, which are all zero.
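
For completeness, here is a minimal sketch (reusing the shapes from the post, but with emb initialized randomly instead of with zeros, and with dim=0 passed to F.softmax to silence the deprecation warning) showing that alphas then receives a nonzero gradient:

# minimal sketch: same convex combination, but with a nonzero emb
import torch
import torch.nn as nn
from torch.nn import functional as F

n_p, n_ctx, dim = 5, 3, 4

emb = nn.Parameter(torch.randn(n_p, n_ctx, dim))   # nonzero instead of torch.zeros
alphas = nn.Parameter(torch.rand(n_p))

# softmax over the prompt dimension; dim=0 avoids the deprecation warning
weights = F.softmax(alphas, dim=0)
out_emb = torch.einsum('ijk,i->ijk', emb, weights).sum(0).unsqueeze(0)

loss = F.mse_loss(out_emb, torch.rand(1, n_ctx, dim))
loss.backward()

print(alphas.grad)  # no longer all zeros, since the gradient is scaled by emb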