Backward doesn't provide gradients!

I have been trying to get the gradients with respect to the input for more than a week now.
Looking at the official tutorial, for getting gradients with respect to the input when the tensor used for backward is not a scalar, it says:

Now in this case y is no longer a scalar. torch.autograd could not compute the full Jacobian directly, but if we just want the vector-Jacobian product, simply pass the vector to backward as argument:
import torch

x = torch.randn(3, requires_grad=True)

y = x * 2
while y.data.norm() < 1000:
    y = y * 2

print(y)
v = torch.tensor([0.1, 1.0, 0.0001], dtype=torch.float)
y.backward(v)

print(x.grad)
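To see what backward(v) actually computes, here is a small sketch. The elementwise function f is chosen for illustration and is not from the tutorial. It checks that x.grad equals v multiplied by the full Jacobian, which torch.autograd.functional.jacobian builds explicitly:

```python
import torch
from torch.autograd.functional import jacobian

def f(x):
    # Elementwise square (illustrative choice): the Jacobian is diag(2 * x)
    return x ** 2

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
v = torch.tensor([0.1, 1.0, 0.0001])

y = f(x)
y.backward(v)            # fills x.grad with the vector-Jacobian product v^T J
vjp = x.grad.clone()

J = jacobian(f, x.detach())  # explicit 3x3 Jacobian, only feasible for small inputs
assert torch.allclose(vjp, v @ J)
```

backward(v) never materializes J. It propagates v through the graph, which is why it works when y is not a scalar.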

However, when I try something like this, the input gradient is always None!
The input only gets a gradient if I backpropagate from loss1, which doesn't make any sense to me!

import torch
import torch.nn as nn

def fc_batchnorm_act(in_, out_, use_bn=True, act=nn.ReLU()):
    return nn.Sequential(nn.Linear(in_,out_),
                         act,
                         nn.BatchNorm1d(out_) if use_bn else nn.Identity())

class Reshape(nn.Module):
    def __init__(self, shape):
        super().__init__()
        self.shape = shape 

    def forward(self, input):
        return input.view(self.shape)

class Contractive_AutoEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(Reshape(shape=(-1, 28*28)),
                                     fc_batchnorm_act(28*28, 400, False))

        self.decoder = nn.Sequential(fc_batchnorm_act(400, 28*28, False, nn.Sigmoid()),
                                     Reshape(shape=(-1, 1, 28, 28)))                             
        
    def forward(self, input):
        outputs_e = self.encoder(input)
        outputs = self.decoder(outputs_e)
        return outputs_e, outputs

def loss_function(output_e, outputs, imgs, device):
 
    criterion = nn.MSELoss()
    assert outputs.shape == imgs.shape ,f'outputs.shape : {outputs.shape} != imgs.shape : {imgs.shape}'
    
    imgs.requires_grad = True 
    loss1 = criterion(outputs, imgs)
    # loss1.backward(retain_graph=True)
    output_e.backward(torch.ones(output_e.size()).to(device), retain_graph=True)
    print(imgs.grad)
    loss2 = torch.mean(pow(imgs.grad,2))
    imgs.requires_grad = False 
    imgs.grad.data.zero_()
    loss = loss1 + loss2 
    return loss

And this is how it is used:

for e in range(epochs):
    for i, (imgs, labels) in enumerate(dataloader_train):
        imgs = imgs.to(device)
        labels = labels.to(device)

        outputs_e, outputs = model(imgs)
        loss = loss_function(outputs_e, outputs, imgs,device)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f'epoch/epochs: {e}/{epochs} loss: {loss.item():.4f}')

What am I missing here? I'd greatly appreciate any help.


For some weird reason, if I set requires_grad outside of loss_function, the backward works, while setting requires_grad inside loss_function fails completely and not even a single iteration runs! This doesn't happen at all if I call loss1.backward() instead!


Hi,

Isn’t the problem that imgs is not a leaf tensor? You can check with imgs.is_leaf.
In particular, the .grad field is only populated for leaf Tensors. If you want it for other Tensors, call imgs.retain_grad() to get the .grad field populated for non-leaf Tensors.
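A minimal sketch of the leaf/non-leaf distinction. The tensors x, y and z are illustrative and are not from the thread:

```python
import torch

x = torch.randn(3, requires_grad=True)  # leaf: created directly by the user
y = x * 2                               # non-leaf: the output of an operation
print(x.is_leaf, y.is_leaf)             # True False

y.retain_grad()                         # ask autograd to also keep y.grad
z = (y ** 2).sum()
z.backward()

print(x.grad)  # populated: x is a leaf with requires_grad=True (equals 8 * x)
print(y.grad)  # populated only because of retain_grad() (equals 2 * y)
```

Without the retain_grad() call, y.grad would stay None after backward().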


Thanks a lot, I really appreciate your kind help. But why does it work when I backprop from loss1?!
For example:

def loss_function(output_e, outputs, imgs, device):
 
    criterion = nn.MSELoss()
    assert outputs.shape == imgs.shape ,f'outputs.shape : {outputs.shape} != imgs.shape : {imgs.shape}'
    
    imgs.requires_grad = True 
    loss1 = criterion(outputs, imgs)
    loss1.backward(retain_graph=True)
    #output_e.backward(torch.ones(outputs_e.size()).to(device), retain_graph=True)
    print(imgs.grad)
    loss2 = torch.mean(pow(imgs.grad,2))
    imgs.requires_grad = False 
    imgs.grad.data.zero_()
    loss = loss1 + loss2 
    return loss

If I use loss1.backward(retain_graph=True), I don't get any errors at all! But when I comment that out and instead call output_e.backward(torch.ones(outputs_e.size()).to(device), retain_graph=True), it fails.
The only way it works is if I set imgs.requires_grad inside the training loop, that is:

The training loop:

for e in range(epochs):
    for i, (imgs, labels) in enumerate(dataloader_train):
        imgs = imgs.to(device)
        labels = labels.to(device)

        imgs.requires_grad_(True)

        outputs_e, outputs = model(imgs)

        loss = loss_function(outputs_e, outputs, imgs, device)
        imgs.requires_grad_(False)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f'epoch/epochs: {e}/{epochs} loss: {loss.item():.4f}')

And the loss:

def loss_function(output_e, outputs, imgs, device):
 
    criterion = nn.MSELoss()
    assert outputs.shape == imgs.shape ,f'outputs.shape : {outputs.shape} != imgs.shape : {imgs.shape}'
    
    loss1 = criterion(outputs, imgs)
    #loss1.backward(retain_graph=True)
    output_e.backward(torch.ones(output_e.size()).to(device), retain_graph=True)
    loss2 = torch.mean(pow(imgs.grad,2))
    imgs.grad.data.zero_()
    loss = loss1 + loss2 
    return loss

And now both loss.backward() and outputs_e.backward() work without a hitch!
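One thing worth checking in this working version: the imgs.grad produced by output_e.backward() is computed without create_graph=True, so it carries no graph. As a result, loss2 acts as a constant, and loss.backward() only trains the model through loss1. If the penalty is meant to train the encoder, the sketch below uses torch.autograd.grad with create_graph=True instead. contractive_loss, lam and the small encoder/decoder are hypothetical. The penalty keeps the thread's ones-vector product, which is the gradient of the summed encoder output, not the full Jacobian norm of the original contractive autoencoder.

```python
import torch
import torch.nn as nn

def contractive_loss(encoder, decoder, imgs, lam=1e-4):
    # Hypothetical helper; lam weights the penalty (an assumption, not from the thread)
    imgs = imgs.clone().requires_grad_(True)  # flag set BEFORE the forward pass
    h = encoder(imgs)
    recon = decoder(h)
    loss1 = nn.functional.mse_loss(recon, imgs.detach())
    # Same ones-vector product as output_e.backward(ones), but create_graph=True
    # keeps g differentiable with respect to the encoder weights
    (g,) = torch.autograd.grad(h, imgs, grad_outputs=torch.ones_like(h),
                               create_graph=True)
    loss2 = g.pow(2).mean()
    return loss1 + lam * loss2

# Small illustrative model shaped like the one in the thread (assumed sizes)
encoder = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 400), nn.ReLU())
decoder = nn.Sequential(nn.Linear(400, 28 * 28), nn.Sigmoid(),
                        nn.Unflatten(1, (1, 28, 28)))
imgs = torch.rand(8, 1, 28, 28)
loss = contractive_loss(encoder, decoder, imgs)
loss.backward()  # gradients now flow from both loss1 and loss2
```

Because torch.autograd.grad returns the gradient instead of writing it to imgs.grad, there is also no need to zero imgs.grad or toggle requires_grad afterwards.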

Hoo,
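The thing is that the requires_grad flag is not retroactive. If you forwarded through your model while the requires_grad flag was False, then you cannot get gradients with respect to that input through the model. Your loss1 gives you some gradients because it uses imgs again (as the MSE target) after the flag was set, so that operation is recorded as depending on imgs. outputs_e was computed before the flag was set, so backpropagating from it never reaches imgs.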
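A small sketch of this non-retroactive behaviour. The linear layer lin and the random inputs are illustrative stand-ins for the model and imgs:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.Linear(4, 4)

# Case 1: flag set AFTER the forward pass, so no gradient reaches x
x = torch.randn(3, 4)
out = lin(x)                 # recorded while x.requires_grad is False
x.requires_grad_(True)       # not retroactive
out.sum().backward()
grad_late = x.grad           # None

# Case 2: x2 used again after the flag is set (like imgs as the target in loss1)
x2 = torch.randn(3, 4)
out2 = lin(x2)
x2.requires_grad_(True)
loss1 = ((out2 - x2) ** 2).mean()  # this op does record x2
loss1.backward()
grad_target = x2.grad        # populated, but only through the target term

# Case 3: flag set BEFORE the forward pass, so the gradient flows through the model
x3 = torch.randn(3, 4, requires_grad=True)
out3 = lin(x3)
out3.backward(torch.ones_like(out3))
grad_early = x3.grad         # each row equals lin.weight summed over dim 0
```

Case 2 mirrors what happened with loss1: imgs got a gradient, but only from its role as the target, not from the path through the encoder and decoder.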


Thanks a gazillion times 🙂