Access to grad_output (the input to the backward)

Hi,

In this tutorial, it is mentioned that backward receives a Tensor (grad_output), but when we run loss.backward(), loss is a scalar. I understand that a scalar is also a Tensor, but if grad_output is something other than the loss, how can I access the grad_output that is computed from the loss?

import torch

class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input.
        """
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input
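To make the question concrete, one can record what backward actually receives. This is a minimal sketch assuming a toy loss `(y ** 2).sum()`; `RecordingReLU` and its `last_grad_output` attribute are hypothetical names used only for illustration. It shows that grad_output is dLoss/dy (here `2 * y`), not the loss itself.

```python
import torch

class RecordingReLU(torch.autograd.Function):
    """Hypothetical ReLU that stores the grad_output it receives in backward."""
    last_grad_output = None

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        # grad_output is d(loss)/d(output of this function)
        RecordingReLU.last_grad_output = grad_output.detach().clone()
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input

x = torch.tensor([-1.0, 2.0, 3.0], requires_grad=True)
y = RecordingReLU.apply(x)   # y = [0, 2, 3]
loss = (y ** 2).sum()        # scalar loss
loss.backward()              # autograd seeds d(loss)/d(loss) = 1

print(RecordingReLU.last_grad_output)  # tensor([0., 4., 6.]), i.e. 2 * y
print(x.grad)                          # tensor([0., 4., 6.])
```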

Hi,

When you call t.backward(), if t is not a tensor with a single element, it will actually complain and ask the user to provide the first grad_output (as a Tensor of the same size as t).
In the particular case where t has a single element, grad_output defaults to 1 (a Tensor of ones with the same shape as t), because dt/dt = 1: that way, what gets computed are the actual gradients of t.
Does that answer your question?
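The two cases above can be checked directly. A small sketch with toy tensors: for a non-scalar output, backward() without an argument raises an error, while passing a seed of the same shape works; for a scalar, the seed defaults to 1.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Non-scalar output: autograd cannot pick the first grad_output by itself
y = x * 2
raised = False
try:
    y.backward()              # raises RuntimeError because y has 3 elements
except RuntimeError:
    raised = True

y.backward(torch.ones_like(y))  # explicit grad_output with the same shape as y
print(x.grad)                   # tensor([2., 2., 2.])

# Scalar output: the seed defaults to 1, so no argument is needed
x.grad = None
loss = (x * 2).sum()
loss.backward()                 # same as loss.backward(torch.tensor(1.0))
print(x.grad)                   # tensor([2., 2., 2.])
```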


Thank you for your reply.
So the backward function always assumes a 1D input? I thought the backward network could be considered the transpose of the forward network in the linear case (fully connected and conv layers), so I expected the backward function to accept an input of the same size as the layer's output. That's why I am puzzled that backward accepts a 1D input.
Here is what I want to implement:
According to the chain rule, we can view computing the gradients of the loss w.r.t. the parameters as dLoss/dy * dy/dparams. I am building a network that computes the second part, i.e. dy/dparams (similar to what happens in the backward function of each module in PyTorch). My question is: what should the input to that network be so that it computes the same gradients that loss.backward() computes for each layer? That way, I could bypass loss.backward() and use my own backward function. But first, as a sanity check, I want to implement backpropagation this way (feeding the gradients computed from the loss into a forward network instead of calling loss.backward()).
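For the fully connected case described above, the chain rule can be written out explicitly. A sketch for a bias-free layer (the symbols W, x, y, X, G are introduced here for illustration): the gradient w.r.t. the input is the upstream gradient multiplied by the transposed weight, which is why a "backward network" built from the same layers with transposed weights propagates dLoss/dy, while the weight gradient additionally needs the layer's forward input.

```latex
% Bias-free linear layer y = W x, with W of shape (m, n), x in R^n, y in R^m
\frac{\partial L}{\partial W} = \frac{\partial L}{\partial y}\, x^{\top},
\qquad
\frac{\partial L}{\partial x} = W^{\top} \frac{\partial L}{\partial y}

% Batched, PyTorch layout: Y = X W^{\top}, X of shape (B, n),
% G = \partial L / \partial Y of shape (B, m)
\frac{\partial L}{\partial W} = G^{\top} X,
\qquad
\frac{\partial L}{\partial X} = G\, W
```

In the batched form, an nn.Linear whose weight is set to W^T maps G to G W, i.e. it computes dL/dX.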

Thanks

Hi,

There is a small difference between theory and practice here:

  • In theory, yes: backward only works with 1D Tensors and computes a vector-Jacobian product, where the Jacobian assumes a 1D input and a 1D output.
  • In practice, your input and output are not 1D. So you get a dLoss/dy that is not 1D but has the same shape as y, and you should return something with the same shape as params. But it should be computed as if these were flattened into 1D vectors multiplied by a 2D Jacobian matrix.

That being said, keep in mind that if your net is a function f with input x and output y (assumed 1D here), then what y.backward(v) computes is v^T J, where J is the Jacobian of f, of size (len(y), len(x)).
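The vector-Jacobian product can be verified numerically. A minimal sketch, assuming a toy function `f(x) = tanh(W @ x)` with random `W`, `x` and `v`: it compares `x.grad` after `y.backward(v)` with `v @ J`, where `J` is built with `torch.autograd.functional.jacobian`.

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)
W = torch.randn(3, 4)

def f(x):
    # Toy map from R^4 to R^3
    return torch.tanh(W @ x)

x = torch.randn(4, requires_grad=True)
v = torch.randn(3)              # plays the role of grad_output, same shape as y

y = f(x)
y.backward(v)                   # x.grad now holds v^T J

J = jacobian(f, x.detach())     # shape (len(y), len(x)) = (3, 4)
print(J.shape)                  # torch.Size([3, 4])
print(torch.allclose(x.grad, v @ J, atol=1e-6))
```

Materializing J is only practical for tiny functions; autograd never builds it and computes v^T J directly.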

Thank you for clarification.
I am still having a hard time replicating PyTorch's loss.backward() from scratch.
I computed dLoss/dy with a CrossEntropy loss and fed it to my backward network. I compared the loss, and also the gradients computed by my network, with those of the same network trained with loss.backward(). But the gradients are not the same as the ones computed by loss.backward(), and I am puzzled why. I would appreciate it if you could point out where my logic is wrong. Here is the code:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

imagesetdir = './'  # directory where MNIST is downloaded

use_cuda = True
batch_size = 64
kwargs = {'num_workers': 0, 'pin_memory': True, 'drop_last':True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(imagesetdir, train=True, download=True,
                    transform=transforms.Compose([
                        transforms.Resize(32),
                        transforms.ToTensor(),
                       
                    ])),
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(imagesetdir, train=False, transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ])),
    batch_size=batch_size, shuffle=False, **kwargs)



class Forward(nn.Module):
    def __init__(self):
        super(Forward, self).__init__()
        self.fc_0 = nn.Linear(1024, 40, bias=False)
        self.fc_1 = nn.Linear(40, 10, bias=False)

    def forward(self, x):
        x0 = self.fc_0(x)
        x1 = self.fc_1(x0)

        return x1, [x0, x1]

class Backward(nn.Module):
    def __init__(self):
        super(Backward, self).__init__()
        self.fc_1 = nn.Linear(10, 40, bias=False)
        self.fc_0 = nn.Linear(40, 1024, bias=False)

    def forward(self, x):
        x1 = self.fc_1(x)
        x0 = self.fc_0(x1)

        return [x0, x1]

def transpose_weights(state_dict):

    state_dict_new = {}
    for k, item in state_dict.items():
        state_dict_new.update({k: item.t()})
    return state_dict_new


modelF = Forward().cuda() # main model
modelB = Backward().cuda() # backward network to compute gradients for modelF

modelC = Forward().cuda() # Forward Control model to compare to BP

optimizerC = optim.Adam(modelC.parameters(), lr=0.0001)
optimizer = optim.Adam(modelF.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss()

# -------Implementing BP without using loss.backward() ----------

n_classes = 10
onehot = torch.zeros(train_loader.batch_size, n_classes).cuda()
Softmax = nn.Softmax(dim=1)
for epoch in range(20):

    loss_running = 0
    lossC_running = 0
    for i, (inputs, target) in enumerate(train_loader):

        inputs = inputs.view(train_loader.batch_size, -1).cuda()
        target = target.cuda()
        onehot.zero_()
        onehot.scatter_(1, target.view(train_loader.batch_size,-1), 1)
        
        # ------------- BP Control ------------------------------------------
        outputsC, _ = modelC(inputs)
        lossC = criterion(outputsC, target)
        paramsC = [p for p in modelC.parameters()]
        optimizerC.zero_grad()
        lossC.backward()
        optimizerC.step()

        lossC_running += lossC.item()

        # -------------Implementing BP bypassing loss.backward()-------------
        modelB.load_state_dict(transpose_weights(modelF.state_dict()) )

        outputs, activationsF = modelF(inputs)
        probs = Softmax(outputs.detach())
        # the gradient of CrossEntropy is pi-yi (pi: softmax of the output, yi:onehot label)
        activationsB = modelB(probs - onehot)

        loss = criterion(outputs, target)
        optimizer.zero_grad()
        

        ParamsF = [p for p in modelF.parameters() if p.requires_grad]

        # copy the backward gradients into the parameter grads.
        for ia, a in enumerate(activationsB):
            pF = ParamsF[ia] # parameters of the forward model
            h = activationsF[ia] # forward activations

            # computing the gradients by multiplying forward activations and backward activations
            pF.grad = torch.matmul(h.t().detach(), a.clone().detach())
                
        optimizer.step()
        loss_running += loss.item()
    
    print(epoch, loss_running/(i+1), lossC_running/(i+1))

Don’t you want to make sure modelF and modelC have the same weights?

modelC is only a control: ideally, I want to train modelF (using modelB instead of loss.backward()), train modelC with the conventional loss.backward(), and then compare the parameters (weights) of modelF with those of modelC to see whether training modelF with modelB replicates the logic behind loss.backward().

Unfortunately, there is no guarantee that two models with the same architecture will converge to exactly the same solution.

Also keep in mind that to train your network, you want the gradients w.r.t. the weights, not w.r.t. the input or the activations. So your use of modelB might not be right. In particular, for a Linear layer the gradients w.r.t. the weights are given by dL/dy * dy/dparams = dL/dy * input, i.e. dL/dW = (dL/dy)^T x for a batch x with y = x W^T.
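The weight-gradient formula for a Linear layer can be checked against autograd in a few lines. A sketch assuming a toy loss `(y ** 2).mean()`, whose dL/dy is known in closed form; the weight gradient is then (dL/dy)^T multiplied by the layer's input.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(5, 3, bias=False)
x = torch.randn(8, 5)                # batch of 8 inputs
y = layer(x)                         # y = x @ W^T, shape (8, 3)
loss = (y ** 2).mean()
loss.backward()

# dL/dy for this particular loss, same shape as y
grad_y = 2 * y.detach() / y.numel()

# Weight gradient: (dL/dy)^T @ input, same shape as layer.weight
grad_W = grad_y.t() @ x
print(torch.allclose(grad_W, layer.weight.grad, atol=1e-6))
```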

A simpler check for this could be:

  • Initialize modelF.
  • Initialize modelB with the transpose of modelF's weights.
  • Initialize modelC with modelF's weights.
  • Get a single sample.
  • Do a forward/backward pass in modelC.
  • Do a forward pass in modelF and modelB.
  • Compare the values computed by modelB with the .grad in modelC.
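The check above could look roughly like this. It is a sketch under several assumptions: the thread's bias-free 1024 -> 40 -> 10 net is rewritten with nn.Sequential, a single random vector stands in for an MNIST image, and cross-entropy is the loss. The backward network's outputs give dL/d(activation), and each weight gradient is that value (transposed) times the layer's forward input.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Control model: bias-free 1024 -> 40 -> 10 net
modelC = nn.Sequential(nn.Linear(1024, 40, bias=False), nn.Linear(40, 10, bias=False))
W0 = modelC[0].weight.detach()  # (40, 1024)
W1 = modelC[1].weight.detach()  # (10, 40)

# Backward network: same layers in reverse order, with transposed weights
modelB = nn.Sequential(nn.Linear(10, 40, bias=False), nn.Linear(40, 1024, bias=False))
with torch.no_grad():
    modelB[0].weight.copy_(W1.t())  # (40, 10)
    modelB[1].weight.copy_(W0.t())  # (1024, 40)

# A single random sample standing in for an image, and its label
x = torch.randn(1, 1024, requires_grad=True)
target = torch.tensor([3])

# Forward/backward in the control model
x0 = modelC[0](x)
out = modelC[1](x0)
F.cross_entropy(out, target).backward()

with torch.no_grad():
    # dL/dout for cross-entropy with mean reduction over the batch
    grad_out = (F.softmax(out, dim=1) - F.one_hot(target, 10).float()) / x.shape[0]
    grad_x0 = modelB[0](grad_out)  # dL/dx0 = grad_out @ W1
    grad_x = modelB[1](grad_x0)    # dL/dx  = grad_x0 @ W0
    grad_W1 = grad_out.t() @ x0    # dL/dW1, shape (10, 40)
    grad_W0 = grad_x0.t() @ x      # dL/dW0, shape (40, 1024)

print(torch.allclose(grad_W1, modelC[1].weight.grad, atol=1e-6))
print(torch.allclose(grad_W0, modelC[0].weight.grad, atol=1e-6))
print(torch.allclose(grad_x, x.grad, atol=1e-6))
```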

Thank you for the suggestions.
I changed the loss function to MSELoss, and the loss of my modelF now decreases. However, it does not give me the same gradients as modelC (computed by loss.backward()). To check what the input to my backward model (modelB) should be, I need access to the input of the backward function inside each nn.Linear module, so that I can compare it with the input to modelB. Is there any way to access the grad_output of each layer in my control model (modelC)?

You can register a Tensor hook on the output of the layer. That way you will have access to the gradient for that Tensor.
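A minimal sketch of the Tensor-hook approach, assuming a toy two-layer model and the loss `out.pow(2).sum()`; `save_grad` and `captured` are hypothetical names. Each hook receives the gradient of the loss w.r.t. the tensor it is registered on, i.e. the grad_output of the layer that produced it.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 3, bias=False), nn.Linear(3, 2, bias=False))
captured = {}

def save_grad(name):
    # Returns a hook that stores the gradient flowing into a tensor
    def hook(grad):
        captured[name] = grad.detach().clone()
    return hook

x = torch.randn(5, 4)
h = model[0](x)
h.register_hook(save_grad("fc_0_out"))    # dL/dh: grad_output of the first layer
out = model[1](h)
out.register_hook(save_grad("fc_1_out"))  # dL/dout: grad_output of the second layer
loss = out.pow(2).sum()
loss.backward()

print(captured["fc_1_out"].shape)  # torch.Size([5, 2]), same shape as out
# For this loss dL/dout = 2 * out; the next gradient back is dL/dout @ W1
print(torch.allclose(captured["fc_1_out"], 2 * out.detach()))
print(torch.allclose(captured["fc_0_out"], captured["fc_1_out"] @ model[1].weight.detach(), atol=1e-6))
```

On the module side, `register_full_backward_hook` gives access to the same grad_output without touching the forward code.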


I successfully bypassed loss.backward() in a simple network, for both MSE loss and CrossEntropy loss, by replacing it with a backward network. For future reference, I am sharing the code. Thanks albanD!

import scipy.stats as ss
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

imagesetdir = './'

use_cuda = True
batch_size = 1024
kwargs = {'num_workers': 0, 'pin_memory': True, 'drop_last':True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(imagesetdir, train=True, download=True,
                    transform=transforms.Compose([
                        transforms.Resize(32),
                        transforms.ToTensor(),
                       
                    ])),
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(imagesetdir, train=False, transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ])),
    batch_size=batch_size, shuffle=False, **kwargs)



class Forward(nn.Module):
    def __init__(self):
        super(Forward, self).__init__()
        self.fc_0 = nn.Linear(1024, 40, bias=False)
        self.fc_1 = nn.Linear(40, 10, bias=False)

    def forward(self, x):
        x0 = self.fc_0(x)
        x1 = self.fc_1(x0)

        return x1, [x, x0, x1]

class Backward(nn.Module):
    def __init__(self):
        super(Backward, self).__init__()
        self.fc_1 = nn.Linear(10, 40, bias=False)
        self.fc_0 = nn.Linear(40, 1024, bias=False)

    def forward(self, x):
        x1 = self.fc_1(x)
        x0 = self.fc_0(x1)

        return x0, [x1, x]

def transpose_weights(state_dict):

    state_dict_new = {}
    for k, item in state_dict.items():
        state_dict_new.update({k: item.t()})
    return state_dict_new

def corr(t0, t1):
    return ss.pearsonr(t0.view(-1).detach().cpu().numpy(),t1.view( -1).detach().cpu().numpy() )

# A simple hook class that stores what a layer receives during the forward/backward pass.
# For a backward hook, `input` and `output` are actually grad_input and grad_output.
class Hook():
    def __init__(self, module, backward=False):
        if not backward:
            self.hook = module.register_forward_hook(self.hook_fn)
        else:
            # register_backward_hook is deprecated; register_full_backward_hook replaces it
            self.hook = module.register_full_backward_hook(self.hook_fn)
    def hook_fn(self, module, input, output):
        self.input = input
        self.output = output
    def close(self):
        self.hook.remove()

modelF = Forward().cuda() # main model
modelB = Backward().cuda() # backward network to compute gradients for modelF

modelC = Forward().cuda() # Forward Control model to compare to BP
modelC.load_state_dict(modelF.state_dict())
modelB.load_state_dict(transpose_weights(modelF.state_dict()) )

optimizerC = optim.Adam(modelC.parameters(), lr=0.0001)
optimizer = optim.Adam(modelF.parameters(),  lr=0.0001)
criterion = nn.CrossEntropyLoss() #nn.MSELoss() #

# -------Implementing BP without using loss.backward() ----------
# backward hooks on each layer of the control model, in order [fc_0, fc_1]
hookC = [Hook(layer, backward=True) for layer in modelC.children()]

n_classes = 10
onehot = torch.zeros(train_loader.batch_size, n_classes).cuda()
Softmax = nn.Softmax(dim=1)
for epoch in range(20):

    loss_running = 0
    lossC_running = 0
    for i, (inputs, target) in enumerate(train_loader):

        inputs = inputs.view(train_loader.batch_size, -1).cuda()
        target = target.cuda()
        onehot.zero_()
        onehot.scatter_(1, target.view(train_loader.batch_size,-1), 1)
        
        # ------------- BP Control ------------------------------------------
        outputsC, activationsC = modelC(inputs)
        lossC = criterion(outputsC, target)
        optimizerC.zero_grad()
        lossC.backward()
        optimizerC.step()

        lossC_running += lossC.item()
        ParamsC = [p for p in modelC.parameters()]

        # -------------Implementing BP bypassing loss.backward()-------------
        modelB.load_state_dict(transpose_weights(modelF.state_dict()) )

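        outputs, activationsF = modelF(inputs)
        probs = Softmax(outputs.detach())
        # The gradient of CrossEntropyLoss w.r.t. the logits is (p_i - y_i) / batch_size
        # (p_i: softmax of the output, y_i: one-hot label; the default reduction averages over the batch).
        # grad_input holds the NEGATIVE of that gradient; the sign is undone when pF.grad is set below.
        grad_input = (onehot - probs) / train_loader.batch_size  # for CrossEntropyLoss

        # grad_input = 2 * (onehot - outputs.detach()) / (train_loader.batch_size * n_classes)  # for MSELoss with one-hot targets

        recons, activationsB = modelB(grad_input)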

        loss = criterion(outputs, target)
        optimizer.zero_grad()
        

        ParamsF = [p for p in modelF.parameters() if p.requires_grad]

        # copy the backward gradients into the parameter grads.
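        for ip, pF in enumerate(ParamsF):
            pC = ParamsC[ip]          # matching parameter of the control model
            hC = hookC[ip].output[0]  # dL/d(output of layer ip) in modelC, recorded by the backward hook

            aF = activationsF[ip]  # forward activations: the input to layer ip
            aB = activationsB[ip]  # backward activations: -dL/d(output of layer ip), should be close to -hC

            # dL/dW = (dL/dy)^T @ input; the minus sign undoes the sign flip in grad_input
            pF.grad = -torch.matmul(aB.t().detach(), aF.clone().detach())
            # pF.grad should be close to pC.grad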
            
        optimizer.step()
        loss_running += loss.item()
    
    print('Epoch %d: Loss= %.3f, Loss_Control= %.3f'%(epoch, loss_running/(i+1), lossC_running/(i+1)))
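The closed-form loss gradients used for grad_input can be checked against autograd on random logits. A sketch assuming both losses use their default mean reduction and that MSELoss is given one-hot targets: CrossEntropyLoss averages over the N samples, MSELoss over all N*C elements.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
N, C = 8, 10
logits = torch.randn(N, C, requires_grad=True)
target = torch.randint(0, C, (N,))
onehot = F.one_hot(target, C).float()

# CrossEntropyLoss (mean over the batch): dL/dlogits = (softmax - onehot) / N
nn.CrossEntropyLoss()(logits, target).backward()
manual_ce = (F.softmax(logits.detach(), dim=1) - onehot) / N
ce_ok = torch.allclose(logits.grad, manual_ce, atol=1e-6)

# MSELoss (mean over all N*C elements): dL/dy = 2 * (y - onehot) / (N * C)
logits.grad = None
nn.MSELoss()(logits, onehot).backward()
manual_mse = 2 * (logits.detach() - onehot) / (N * C)
mse_ok = torch.allclose(logits.grad, manual_mse, atol=1e-6)

print(ce_ok, mse_ok)
```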