Param.grad is always None

import torch
import torch.nn as nn
import torch.optim as optim

def calculate_fisher_matrix_3(task_model, train_loader, criterion, feature_extractor, num_samples=1000):
    task_model.train()
    optimizer = optim.Adam(task_model.parameters(), lr=0.001)  # note: only zero_grad() is ever called on this optimizer
    fisher_matrix = [torch.zeros_like(param.data) for param in task_model.parameters()]
    total_samples = 0

    for data in train_loader:
        inputs = data["pixel_values"].requires_grad_(True)
        labels = data["labels"].float().requires_grad_(True)
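        # note: requires_grad_(True) on inputs/labels only requests gradients w.r.t. these
        # data tensors; it does nothing for the model parameters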

        sizei = inputs.shape[2], inputs.shape[3]

        optimizer.zero_grad()

        # Forward pass
        logits = task_model(inputs)[0]
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=sizei,  # (height, width)
            mode='bilinear',
            align_corners=False
        )
        logits = upsampled_logits.argmax(dim=1)
        logits = logits.float()
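        # note: argmax is non-differentiable; the result has requires_grad=False,
        # and casting it to float does not reattach it to the computation graph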

        # Compute loss
        loss = criterion(logits, labels)

        # Compute gradients for fisher_matrix
        loss.retain_grad()  # only makes loss retain its own .grad; does not affect parameter gradients
        loss.backward()

        # accumulate squared parameter gradients (diagonal Fisher estimate)
        for i, param in enumerate(task_model.parameters()):
            if param.grad is not None:
                fisher_matrix[i] += param.grad.detach() ** 2

        # Reset gradients for the next batch
        task_model.zero_grad()

        total_samples += len(inputs)

        if total_samples >= num_samples:
            break

    # average the accumulated squared gradients over the samples seen
    for i in range(len(fisher_matrix)):
        fisher_matrix[i] /= total_samples

    return fisher_matrix

I want to compute the Fisher information matrix (FIM), but param.grad is never updated during the backward pass, so the accumulated Fisher matrix stays all zeros. What am I doing wrong?
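
The most likely cause is the line logits = upsampled_logits.argmax(dim=1). argmax is not differentiable, so it returns a new tensor with requires_grad=False that is cut off from the computation graph, and casting it back to float does not reattach it. loss.backward() still runs without an error only because labels was marked with requires_grad_(True), so the gradient flows into labels instead of into the model parameters, and every param.grad stays None.

The fix is to compute the loss on the raw upsampled logits and keep the targets out of the autograd graph. Below is a minimal corrected sketch (names are illustrative). It assumes a segmentation-style model whose forward pass returns the logits as its first output, and a criterion such as nn.CrossEntropyLoss that accepts raw logits of shape (N, C, H, W) together with integer label maps of shape (N, H, W); if your criterion expects float targets, cast them with .float(), but never call requires_grad_ on them.

import torch
import torch.nn as nn

def calculate_fisher_matrix_fixed(task_model, train_loader, criterion, num_samples=1000):
    task_model.train()
    # one zero tensor per parameter: the running diagonal Fisher estimate
    fisher_matrix = [torch.zeros_like(param.data) for param in task_model.parameters()]
    total_samples = 0

    for data in train_loader:
        inputs = data["pixel_values"]   # parameter gradients do not need requires_grad on inputs
        labels = data["labels"]         # targets must stay out of the autograd graph

        task_model.zero_grad()

        # Forward pass
        logits = task_model(inputs)[0]
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=(inputs.shape[2], inputs.shape[3]),  # (height, width)
            mode="bilinear",
            align_corners=False,
        )

        # Loss on the raw logits; no argmax anywhere in the differentiable path
        loss = criterion(upsampled_logits, labels)
        loss.backward()

        # Accumulate squared parameter gradients
        for i, param in enumerate(task_model.parameters()):
            if param.grad is not None:
                fisher_matrix[i] += param.grad.detach() ** 2

        total_samples += inputs.shape[0]
        if total_samples >= num_samples:
            break

    # Average over the samples seen
    for i in range(len(fisher_matrix)):
        fisher_matrix[i] /= total_samples

    return fisher_matrix

Note that this accumulates squared batch gradients of the loss, which is the common diagonal empirical-Fisher approximation; the textbook Fisher information uses per-example gradients of the log-likelihood, so for a stricter estimate use batch size 1 or loop over individual samples.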