Can anyone help me compute the gradient of each sample in a mini-batch efficiently, using something other than the approach from the original forum thread?
Here is my code.
# Loop-based per-sample gradients: correct, but it runs one autograd pass per
# sample. NOTE(review): `model`, `inputs` and `labels` are defined outside
# this snippet.
criterion = nn.CrossEntropyLoss(reduction='none')  # keep one loss per sample
loss_fk = criterion(model(inputs), labels)
# Materialise the parameters once instead of re-creating the generator per sample.
param_list = list(model.parameters())
# retain_graph=True because every per-sample loss shares the same forward graph.
per_sample_grads = [
    torch.autograd.grad(loss, param_list, retain_graph=True)
    for loss in loss_fk
]
mat1 and mat2 shapes cannot be multiplied (1024x25 and 400x120)
The same code raises this error, but only when I use vmap.
When I remove vmap and call the loss_fn function directly, it returns the 64 losses as expected.
The error only appears when I use this line: vmap(grad(loss_fn, argnums=(0)), in_dims=(None, 0, 0))(params, inputs, labels), where grad is torch.func.grad.
Here is my test model.
class SimpleCNN(nn.Module):
    """LeNet-style CNN for 3-channel images.

    ``fc1`` expects ``16 * 5 * 5 = 400`` flattened features, which is what the
    two conv/pool stages produce for 32x32 inputs (32 -> 28 -> 14 -> 10 -> 5).
    ``forward`` accepts a batched ``(B, 3, H, W)`` tensor or a single
    ``(3, H, W)`` sample. It can therefore run inside ``torch.func.vmap``,
    which hides the batch dimension from the wrapped function.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten the trailing (C, H, W) dims. Counting from the end is
        # correct both for batched input and under vmap. start_dim=1 would
        # flatten only (H, W) of a vmapped sample, giving (16, 25) instead of
        # (400,) and mismatching fc1.
        x = torch.flatten(x, start_dim=-3)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
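vmap removes the batch dimension, so inside the vmapped function every operation sees a single sample. You therefore need to use dimensions counted from the end (negative indices) in torch functions. For example, a tensor of shape [B, C, H, W] is seen as [C, H, W] by your torch.flatten(x, 1) call, so only H and W get flattened. In your model the [16, 5, 5] feature map becomes [16, 25]; across the 64 samples that is the 1024x25 matrix in the error message.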
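As the error message states, this intermediate vector no longer matches the number of input features of self.fc1. In general, fc1's input size must equal the flattened feature size for your input resolution. 16 * 5 * 5 = 400 is correct for 32x32 images. The example below uses random 20x20 inputs and therefore changes fc1 to 16 * 2 * 2 = 64 inputs.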
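You also need to change the start_dim argument of torch.flatten to be relative (negative) rather than absolute, because the batch dimension is removed when vmapping. start_dim=-3 works for both a batched [B, C, H, W] input and a single [C, H, W] sample. Example code is below.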
import torch
import torch.nn.functional as F
from torch import nn
from torch.func import functional_call, grad, jacrev, vmap
class SimpleCNN(nn.Module):
    """LeNet-style CNN whose forward pass also works on unbatched samples (vmap).

    Args:
        fc1_in_features: Number of features obtained by flattening the output
            of the second conv/pool stage. It depends on the input resolution.
            The default ``16 * 2 * 2 = 64`` matches 20x20 inputs
            (20 -> 16 -> 8 -> 4 -> 2). Pass ``16 * 5 * 5 = 400`` for 32x32
            inputs.
    """

    def __init__(self, fc1_in_features: int = 16 * 2 * 2):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(fc1_in_features, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten the trailing (C, H, W) dims; start_dim is counted from the
        # end so it stays correct when vmap hides the batch dimension.
        x = torch.flatten(x, start_dim=-3)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
net = SimpleCNN()  # the network; fc1 is sized for the 20x20 inputs used below


def fcall(params, x):
    """Run ``net`` with the parameter dict ``params`` on input ``x``."""
    return functional_call(net, params, x)


# Detached copies: the torch.func transforms compute their own derivatives,
# so the regular autograd graph on the parameters is not needed.
params = {name: p.detach() for name, p in net.named_parameters()}
x = torch.randn(4, 3, 20, 20)  # random batch of 4 samples, each 3x20x20
y = net(x)  # plain batched forward pass as a sanity check, shape (4, 10)

# Per-sample Jacobians of the 10 logits w.r.t. every parameter (not loss
# gradients): grad_params[name] has shape (4, 10, *params[name].shape).
grad_params = vmap(jacrev(fcall, argnums=0), in_dims=(None, 0))(params, x)
print(grad_params)


def compute_loss(params, sample, target):
    """Return the cross-entropy loss of one sample, for use under vmap."""
    # vmap hides the batch dimension, so re-add a batch of one for model/loss.
    logits = functional_call(net, params, (sample.unsqueeze(0),))
    return F.cross_entropy(logits, target.unsqueeze(0))


labels = torch.randint(0, 10, (4,))  # random class targets for the 4 samples
# Per-sample gradients of the loss w.r.t. the parameters:
# per_sample_grads[name] has shape (4, *params[name].shape).
per_sample_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))(params, x, labels)