How to get middle layer gradient for batch input?

Hi, I am trying to acquire the gradient of the last hidden layer for each sample in a batch of inputs.

Let's say the last layer is a linear layer (512 * 100), where 100 is the number of classes.

For each sample, the gradient of that layer's weight would be 100 * 512.
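
(For reference, a quick standalone check of that shape; nn.Linear stores its weight as [out_features, in_features], and the gradient has the same shape:)

import torch.nn as nn

fc = nn.Linear(512, 100)   # 512 input features, 100 classes
print(fc.weight.shape)     # torch.Size([100, 512]) -- fc.weight.grad will match this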

Currently, my code is as follows.

losses = self.criterion_without_reduction(outputs, targets)
gradient_batch = []
for loss in losses:
    model.zero_grad()  # clear the previous sample's gradients so they don't accumulate
    loss.backward(retain_graph=True)
    # iterate the parameters in reverse so the last layer comes first
    for name, parameter in reversed(list(model.net.named_parameters())):
        if 'weight' in name:
            # only keep the gradient of the last layer's weight: [100, 512]
            gradient = parameter.grad.clone().unsqueeze(0)  # [1, 100, 512]
            gradient_batch.append(gradient.unsqueeze(0))    # [1, 1, 100, 512]
            break
gradient_batch = torch.cat(gradient_batch, dim=0)

The above code generates a tensor of shape [batch_size, 1, 100, 512].

However, the code runs relatively slowly, and I am wondering whether there is a more efficient way to do it?
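
(Side note for later readers: if the per-sample backward() calls themselves are the bottleneck, the last layer's per-sample gradients can also be computed analytically, since for a linear layer the weight gradient of one sample is the outer product of the loss gradient w.r.t. the layer's output and the layer's input. A minimal sketch, not from this thread, assuming feats holds the 512-dim inputs to the final layer, which you would need to capture yourself, e.g. with a forward hook:)

# feats: [batch_size, 512], hypothetical capture of the final layer's inputs
# outputs: [batch_size, 100], the logits fed to the criterion
losses = self.criterion_without_reduction(outputs, targets)                   # [batch_size]
grad_out = torch.autograd.grad(losses.sum(), outputs, retain_graph=True)[0]  # [batch_size, 100]
# per-sample weight gradient = outer product of grad_out and the layer input
per_sample_grads = torch.einsum('bo,bi->boi', grad_out, feats)               # [batch_size, 100, 512]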

You could access the gradient directly: model.gradient = model.net.fc.weight.grad.clone() (replace fc with the name of your classification layer). The time is most likely spent in the per-sample loss.backward() calls; the iteration over named_parameters is probably not the bottleneck.
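
(For concreteness, a minimal sketch of that direct access inside the per-sample loop, assuming the classification layer is registered as model.net.fc:)

losses = self.criterion_without_reduction(outputs, targets)
gradient_batch = []
for loss in losses:
    model.zero_grad()  # avoid accumulating gradients across samples
    loss.backward(retain_graph=True)
    gradient_batch.append(model.net.fc.weight.grad.clone().unsqueeze(0))
gradient_batch = torch.cat(gradient_batch, dim=0)  # [batch_size, 100, 512]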

Hi,

Forgive me if I make stupid mistakes…

I am running the code in eval() mode and trying to get the gradient matrix for each input x individually.

For example, if we have 128 inputs (in a batch), we will get 128 different gradient matrices.

I tried using model.net.fc.weight.grad.clone() to get the gradient, but model.net.fc.weight.grad is a NoneType object…

Is there any solution for that?

Thank you in advance!

Ideally you should run this after model.train(), but in this case it shouldn't cause any problem, since eval() mode doesn't disable gradient computation. If model.net.fc exists, then the backward() call apparently isn't populating that layer's gradients. Since your first code does work, I suspect you are accessing the classification module under the wrong name (fc in this case).
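
(A quick way to check what the classification layer is actually called is to print the registered module names; a generic sketch:)

for name, module in model.net.named_modules():
    print(name, '->', type(module).__name__)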