Param.grad is always None

import torch
import torch.nn as nn

def calculate_fisher_matrix_3(task_model, train_loader, criterion, feature_extractor, num_samples=1000):
    # Accumulates squared parameter gradients: the diagonal of the
    # empirical Fisher information matrix.
    task_model.train()  # consider eval() to keep dropout/batchnorm deterministic
    fisher_matrix = [torch.zeros_like(param.data) for param in task_model.parameters()]
    total_samples = 0

    for i, data in enumerate(train_loader, 0):
        # Inputs and labels do not need requires_grad: we differentiate
        # with respect to the model parameters, not the data. Setting
        # requires_grad_(True) on labels only hid the real problem.
        inputs = data["pixel_values"]
        # Cast labels to whatever dtype your criterion expects
        # (e.g. .long() for nn.CrossEntropyLoss).
        labels = data["labels"]

        sizei = data["pixel_values"].shape[2], data["pixel_values"].shape[3]

        # No optimizer is needed just to read gradients (the original
        # Adam optimizer also referenced an undefined name, model).
        task_model.zero_grad()
        # Forward pass
        logits = task_model(inputs)[0]
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=sizei,  # (height, width)
            mode='bilinear',
            align_corners=False
        )

        # Compute the loss on the raw logits. The original code took
        # argmax first, but argmax is not differentiable: it returns a
        # tensor with no grad_fn, so backward() never reached the model
        # parameters and every param.grad stayed None.
        loss = criterion(upsampled_logits, labels)

        # Compute gradients for fisher_matrix. (loss.retain_grad() was
        # removed: it only keeps .grad on the loss tensor itself and
        # does nothing for the parameters.)
        loss.backward()

        # The inner loop variable is renamed so it no longer shadows
        # the batch index i.
        for j, param in enumerate(task_model.parameters()):
            if param.grad is not None:
                fisher_matrix[j] += param.grad.detach() ** 2

        # Gradients are zeroed at the top of each iteration, so no
        # extra reset is needed here.
        total_samples += len(inputs)

        if total_samples >= num_samples:
            break

    # Average the accumulated squared gradients over the samples seen.
    # Note: backward() above yields per-batch gradients, so this is a
    # batch-level approximation of the per-sample empirical Fisher.
    for i in range(len(fisher_matrix)):
        fisher_matrix[i] /= total_samples

    return fisher_matrix
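
For context, this is roughly how the function is meant to be called. The model, loader, and variable names below are placeholders I made up for illustration, not part of the original code:

criterion = nn.CrossEntropyLoss()  # expects integer class labels of shape (N, H, W)
fisher = calculate_fisher_matrix_3(
    task_model=my_segmentation_model,  # hypothetical: any model returning logits as output[0]
    train_loader=my_train_loader,      # hypothetical: yields {"pixel_values": ..., "labels": ...}
    criterion=criterion,
    feature_extractor=None,            # accepted but unused by the function
    num_samples=1000,
)
# Sanity check: with the fix, the accumulated values should be non-zero.
print(sum(f.sum().item() for f in fisher))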

I want to compute the Fisher information matrix (FIM). With my original version, param.grad was always None after loss.backward(): the loss was computed on the argmax of the logits, and argmax is not differentiable, so the computation graph was cut before it ever reached the model parameters. Computing the loss on the upsampled logits directly, as in the corrected code above, fixes this.
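
A minimal, self-contained sketch of the failure mode (all tensors here are illustrative):

import torch

w = torch.randn(3, 3, requires_grad=True)
x = torch.randn(2, 3)
logits = x @ w               # differentiable: logits.grad_fn is set
hard = logits.argmax(dim=1)  # argmax cuts the graph: hard.grad_fn is None
print(hard.grad_fn)          # None

loss = (logits ** 2).mean()  # a loss built from the logits instead
loss.backward()
print(w.grad is not None)    # True: gradients now reach w

In the original code, loss.backward() only ran without error because labels had requires_grad_(True); the gradient flowed into labels rather than into the model parameters.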