def calculate_fisher_matrix_3(task_model, train_loader, criterion, feature_extractor, num_samples=1000):
    """Estimate the diagonal of the empirical Fisher Information Matrix.

    Accumulates the squared per-batch parameter gradients of ``criterion``
    over at most ``num_samples`` samples from ``train_loader`` and averages
    them.

    Args:
        task_model: segmentation model; ``task_model(inputs)[0]`` must yield
            logits of shape (N, C, h, w).
        train_loader: iterable of dicts with keys ``"pixel_values"``
            (N, C_in, H, W) and ``"labels"``.
        criterion: differentiable loss applied to the upsampled logits and
            the labels.
        feature_extractor: unused; kept for interface compatibility.
        num_samples: stop accumulating once this many samples were seen.

    Returns:
        list[torch.Tensor]: one tensor per parameter of ``task_model``, each
        holding the averaged squared gradients (same shape as the parameter).
    """
    # NOTE(review): train() keeps dropout/batchnorm in training mode; for a
    # pure FIM estimate eval() is usually preferred — confirm intent.
    task_model.train()
    # BUG FIX: the original built an Adam optimizer on an undefined name
    # `model` (NameError) and never stepped it — removed entirely.
    fisher_matrix = [torch.zeros_like(param.data) for param in task_model.parameters()]
    total_samples = 0

    for data in train_loader:
        # Gradients are needed w.r.t. parameters only; requires_grad_() on
        # inputs/labels (as in the original) contributes nothing here.
        inputs = data["pixel_values"]
        # NOTE(review): .float() matches the original; if `criterion` is
        # CrossEntropyLoss the labels must instead stay integer (long).
        labels = data["labels"].float()
        target_hw = inputs.shape[2], inputs.shape[3]

        # Clear any stale parameter gradients before this batch's backward.
        task_model.zero_grad()

        logits = task_model(inputs)[0]
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=target_hw,  # (height, width)
            mode='bilinear',
            align_corners=False,
        )

        # BUG FIX: the original took argmax(dim=1) before the loss. argmax is
        # non-differentiable, which severed the autograd graph — backward()
        # then left every param.grad as None and the Fisher stayed zero.
        # Feed the differentiable logits to the criterion instead.
        loss = criterion(upsampled_logits, labels)
        loss.backward()

        # Accumulate squared gradients (diagonal empirical Fisher). Use a
        # distinct index name: the original reused `i` for both loops.
        for p_idx, param in enumerate(task_model.parameters()):
            if param.grad is not None:
                fisher_matrix[p_idx] += param.grad.detach() ** 2

        # Reset gradients for the next batch.
        task_model.zero_grad()

        total_samples += len(inputs)
        if total_samples >= num_samples:
            break

    # Guard against an empty loader (original divided by zero here).
    if total_samples > 0:
        for p_idx in range(len(fisher_matrix)):
            fisher_matrix[p_idx] /= total_samples
    return fisher_matrix
I want to compute the Fisher Information Matrix, but `param.grad` is never updated after `loss.backward()`. Could someone help me find what breaks the gradient flow?