How to get the gradients of each sample in a batch efficiently

Can anyone help me get the gradient of each sample in a mini-batch efficiently (not the approach from the original forum thread)?

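Here is my code:

```python
import torch
from torch import nn

criterion = nn.CrossEntropyLoss(reduction='none')
loss_fk = criterion(model(inputs), labels)  # one loss value per sample
# One autograd.grad call per sample, keeping the graph alive between calls
per_sample_grads = [torch.autograd.grad(loss, model.parameters(), retain_graph=True) for loss in loss_fk]
```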

Thanks in advance.

Hi @likhilb9,

You can do this with the torch.func framework as follows (model, inputs and labels are the ones from your code):

```python
import torch
from torch import nn
from torch.func import vmap, grad, functional_call

criterion = nn.CrossEntropyLoss(reduction='none')

def fcall(params, inputs):
  return functional_call(model, params, inputs) # functional call (params are now an input)

def loss_fn(params, inputs, labels):
  outputs = fcall(params, inputs)
  loss = criterion(outputs, labels)
  return loss

params = dict(model.named_parameters()) # params must be a dict (not a list)

# grad differentiates loss_fn w.r.t. params (argnums=0); vmap maps over dim 0 of
# inputs and labels and shares the same params across all samples (None).
loss_wrt_params = vmap(grad(loss_fn, argnums=0), in_dims=(None, 0, 0))(params, inputs, labels) # see the torch.func.grad docs for reference
```
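A self-contained sketch for checking this idea against the loop from the question. The small nn.Sequential model and the random data are hypothetical stand-ins for model, inputs and labels. As an assumption borrowed from PyTorch's per-sample-gradients tutorial, loss_fn re-adds a batch dimension of size 1 with unsqueeze(0), so the model sees the batched shapes it was written for, and then squeezes the loss back to the scalar that grad requires:

```python
import torch
from torch import nn
from torch.func import vmap, grad, functional_call

torch.manual_seed(0)

# Hypothetical stand-ins for the question's model, inputs and labels.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))
inputs = torch.randn(5, 8)
labels = torch.randint(0, 3, (5,))

criterion = nn.CrossEntropyLoss(reduction='none')

def loss_fn(params, x, y):
    # vmap strips the batch dim, so re-add a batch of size 1 for the model,
    # then squeeze the [1]-shaped loss back to the scalar grad() needs.
    out = functional_call(model, params, (x.unsqueeze(0),))
    return criterion(out, y.unsqueeze(0)).squeeze(0)

params = {k: v.detach() for k, v in model.named_parameters()}

# dict: parameter name -> tensor of shape [batch, *param.shape]
per_sample_grads = vmap(grad(loss_fn), in_dims=(None, 0, 0))(params, inputs, labels)

# Reference: the per-sample loop from the question.
losses = criterion(model(inputs), labels)
loop_grads = [torch.autograd.grad(l, model.parameters(), retain_graph=True)
              for l in losses]

for j, name in enumerate(params):
    reference = torch.stack([g[j] for g in loop_grads])
    print(name, tuple(per_sample_grads[name].shape),
          torch.allclose(per_sample_grads[name], reference, atol=1e-5))
```

Both paths should agree up to floating-point noise; the vmap version replaces the Python loop of backward passes with a single vectorized call.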

Also, make sure to wrap your code in three backticks (```) so it gets formatted properly.

EDIT: Corrected missing vmap call.

The above code isn't working properly: there are some issues with batch norm, and I also couldn't see where you use vmap.

Hi @likhilb9,

I’ve made the correction for the vmap/grad call.

To deal with batch norm layers, you can use the torch.func.replace_all_batch_norm_modules_ function (see its documentation).

You can find more information on patching the batch norm layer in the documentation here: Patching Batch Norm — PyTorch 2.2 documentation
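A minimal sketch of the batch norm patch, assuming a hypothetical small conv net with one BatchNorm2d layer and random data. replace_all_batch_norm_modules_ modifies the model in place so batch norm stops tracking running statistics. Each vmapped call then normalizes a single sample with its own per-channel statistics (over H×W), so the result matches running each sample through the patched model on its own, not a full-batch forward pass:

```python
import torch
from torch import nn
from torch.func import vmap, grad, functional_call, replace_all_batch_norm_modules_

torch.manual_seed(0)

# Hypothetical small conv net with a BatchNorm2d layer.
model = nn.Sequential(
    nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4), nn.ReLU(),
    nn.Flatten(), nn.Linear(4 * 6 * 6, 10),
)
# In place: removes running_mean/running_var, so batch norm always uses
# the statistics of the input it receives.
replace_all_batch_norm_modules_(model)

inputs = torch.randn(4, 3, 8, 8)
labels = torch.randint(0, 10, (4,))

def loss_fn(params, x, y):
    # x is one [3, 8, 8] sample inside vmap; BatchNorm2d needs a 4D input.
    out = functional_call(model, params, (x.unsqueeze(0),))
    return nn.functional.cross_entropy(out, y.unsqueeze(0))  # scalar loss

params = {k: v.detach() for k, v in model.named_parameters()}
per_sample_grads = vmap(grad(loss_fn), in_dims=(None, 0, 0))(params, inputs, labels)
print({k: tuple(v.shape) for k, v in per_sample_grads.items()})
```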

The same code gives me this error, but only when using vmap:

```
mat1 and mat2 shapes cannot be multiplied (1024x25 and 400x120)
```

When I remove vmap and just call the loss_fn function, it returns the 64 losses. The error only appears when I run this line:

```python
vmap(grad(loss_fn, argnums=(0)), in_dims=(None, 0, 0))(params, inputs, labels) # torch.func.grad for reference
```

Here is my test model:

```python
import torch
from torch import nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
```

vmap removes the batch dimension, so inside the vmapped function you need to use dims relative to the end (negative dims). For example, a tensor of shape [B, N] is seen as shape [N] by the function vmap calls, so torch.flatten(x, 1) in your forward no longer starts at the dimension you intended.
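A small shape experiment illustrates this, assuming per-sample activations of shape [16, 5, 5] after the second pooling layer (32×32 inputs) and a batch of 64 random tensors:

```python
import torch
from torch.func import vmap

# Activations after conv2 + pool for a batch of 64 samples of size 32x32.
batch = torch.randn(64, 16, 5, 5)

# Inside vmap each call sees one sample of shape [16, 5, 5] (no batch dim).
absolute = vmap(lambda t: torch.flatten(t, 1))(batch)   # flattens only the [5, 5] part
relative = vmap(lambda t: torch.flatten(t, -3))(batch)  # flattens all of [16, 5, 5]

print(absolute.shape)  # torch.Size([64, 16, 25])
print(relative.shape)  # torch.Size([64, 400])
```

With dim 1, each sample becomes [16, 25] instead of [400]; over a batch of 64 that is 64 × 16 = 1024 rows of 25 features, consistent with the 1024x25 operand in the error above.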

Any other solutions, please?

Hi @likhilb9,

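As the error message says, the intermediate tensor reaching self.fc1 does not have the number of input features self.fc1 expects. The in_features of self.fc1 must match the number of features the conv/pool layers produce for your input size: the example below uses 20×20 inputs, which give 16 * 2 * 2 = 64 features, so self.fc1 is changed to 64 there (for 32×32 inputs, the original 16 * 5 * 5 = 400 is correct).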
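One way to check the required in_features for a given input size (a sketch; flat_features is a hypothetical helper, not a PyTorch API) is to push a dummy input through the same convolutional layers and count the features:

```python
import torch
from torch import nn
import torch.nn.functional as F

# Same convolutional front end as SimpleCNN.
conv1 = nn.Conv2d(3, 6, 5)
pool = nn.MaxPool2d(2, 2)
conv2 = nn.Conv2d(6, 16, 5)

def flat_features(height, width):
    # Hypothetical helper: run one dummy sample through the conv stack
    # and count the features that would reach fc1.
    with torch.no_grad():
        x = torch.zeros(1, 3, height, width)
        x = pool(F.relu(conv1(x)))
        x = pool(F.relu(conv2(x)))
    return x[0].numel()

print(flat_features(20, 20))  # 64  = 16 * 2 * 2
print(flat_features(32, 32))  # 400 = 16 * 5 * 5
```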

You also need to change the torch.flatten dim argument from absolute to relative, because the batch dimension is removed when vmapping. Here is an example:

```python
import torch
from torch import nn
from torch.func import vmap, jacrev, functional_call
import torch.nn.functional as F

class SimpleCNN(nn.Module):

  def __init__(self):
    super(SimpleCNN, self).__init__()

    self.conv1 = nn.Conv2d(3, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 5)
    self.fc1 = nn.Linear(64, 120) # 64 = 16 * 2 * 2 features for the 20x20 input below (use 16 * 5 * 5 = 400 for 32x32 inputs)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = torch.flatten(x, -3) # changed from 1 to -3: flatten the last 3 dims (C, H, W), counted from the right, so it works with or without a batch dim
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

net = SimpleCNN() # the network

def fcall(params, x):
  return functional_call(net, params, x)

params = dict(net.named_parameters())
x = torch.randn(4, 3, 20, 20) # random batch of 4 samples, each 3x20x20

y = net(x) # sanity check: the same model also runs on the full batch

grad_params = vmap(jacrev(fcall, argnums=0), in_dims=(None, 0))(params, x)
# Per-sample Jacobians of the 10 outputs w.r.t. each parameter: shape [4, 10, *param.shape]
print({name: tuple(j.shape) for name, j in grad_params.items()})
```
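Note that jacrev(fcall) above returns per-sample Jacobians of the 10 outputs rather than gradients of a loss. If you want per-sample loss gradients without editing the model, one alternative (swapped in here; it is the approach used in PyTorch's per-sample-gradients tutorial) is to re-add a batch dimension of size 1 inside the vmapped function. This sketch keeps SimpleCNN exactly as in the question, assumes 32×32 inputs, and uses random tensors as placeholder data:

```python
import torch
from torch import nn
import torch.nn.functional as F
from torch.func import vmap, grad, functional_call

class SimpleCNN(nn.Module):
    # The model from the question, unchanged: fc1 expects 16 * 5 * 5 = 400
    # features (32x32 inputs) and flatten still uses dim 1.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

model = SimpleCNN()
inputs = torch.randn(64, 3, 32, 32)     # placeholder batch
labels = torch.randint(0, 10, (64,))    # placeholder labels

def loss_fn(params, x, y):
    # vmap hands us a single [3, 32, 32] sample; unsqueeze restores the batch
    # dim the model expects, so flatten(x, 1) sees [1, 16, 5, 5] as designed.
    out = functional_call(model, params, (x.unsqueeze(0),))
    return F.cross_entropy(out, y.unsqueeze(0))  # scalar, as grad() requires

params = {k: v.detach() for k, v in model.named_parameters()}
per_sample_grads = vmap(grad(loss_fn), in_dims=(None, 0, 0))(params, inputs, labels)

for name, g in per_sample_grads.items():
    print(name, tuple(g.shape))  # e.g. fc1.weight -> (64, 120, 400)
```

This avoids touching the model definition at all, at the cost of a tiny unsqueeze per sample inside the vmapped function.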