Calculate gradient with respect to data label

Hi everyone,

I’m working on implementing a technique from a research paper, FedML-HE, in which I need to calculate gradients with respect to data labels. Here’s the approach I’m following:

  1. Calculate the gradient of the loss with respect to the model’s parameters (grad1).
  2. Calculate the gradient of grad1 with respect to the data labels.

However, I’m encountering an issue where the gradient with respect to the data labels is always None. It seems that the data labels (y) are not part of the computational graph (y.grad_fn returns None).

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.nn.utils import parameters_to_vector

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Load CIFAR10 dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=True)

model = LeNet()
criterion = nn.CrossEntropyLoss()

# Get a batch of data
x, y = next(iter(trainloader))

y = y.float().requires_grad_(True)

output = model(x)
loss = criterion(output, y.long())

first_order_grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)

Jm_list = []

for grad in first_order_grads:
    if grad is not None:
        for grad_element in parameters_to_vector(grad): 
            Jm = torch.autograd.grad(grad_element, y, retain_graph=True, allow_unused=True)[0]
            Jm_list.append(Jm)

I’m looking for advice on two points:

  1. How can I include data labels in the computational graph?
  2. Are there alternative methods to calculate gradients with respect to data labels effectively?

Any insights or references to similar implementations would be greatly appreciated!

This is expected since the target wasn’t used in the forward pass and no trainable parameters of the model were used to create it. Thus changing the parameters in any direction will not influence the target at all, which is represented by the None gradient.

Could you explain your use case in more detail and why you are expecting to see gradients?

Thank you for explaining why taking the derivative with respect to the data label always returns None.

The main idea comes from the research paper, which discusses the sensitivity of model gradients to changes in the input data labels. The paper suggests that a larger gradient with respect to a particular data label indicates a higher sensitivity of the model to that specific input.

Here’s what I understand: the sensitivity analysis aims to quantify how much the gradient (that is, the direction and magnitude of model learning) changes with the data label of each data point, which identifies the model parameters that are most sensitive to the training data.

The way I would do it is

y = y.long()
y.requires_grad = True
yhat = model(x)
loss = criterion(yhat, y)
loss.backward()
print(y.grad)

Also I am not exactly sure about this, but you might also need to add this:

yopt = torch.optim.SGD([y])

just so that PyTorch calculates gradients for y. But maybe this isn’t needed.

Would this mean that you would like to compute the gradients of loss0 w.r.t. all trainable parameters (let’s call these gradients g0) and compare them to the gradients of another loss1 (let’s call these g1)?
If so, I don’t think the target itself has to be differentiable; it sounds as if you want to change the labels and then recompute the gradients of the model parameters.

Let me know if I misunderstood your use case.
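
A minimal sketch of that idea, assuming a small stand-in model rather than the LeNet above: the parameter gradients are recomputed for two different integer labels and then compared, so the label never has to be differentiable. The names `param_grads` and `diff_norm` and the model sizes are illustrative assumptions, not part of the original posts.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in model and input, used only for illustration.
model = nn.Linear(4, 3)
x = torch.randn(1, 4)
criterion = nn.CrossEntropyLoss()

def param_grads(label):
    """Flattened gradient of the loss w.r.t. all parameters for one integer label."""
    loss = criterion(model(x), torch.tensor([label]))
    grads = torch.autograd.grad(loss, list(model.parameters()))
    return torch.cat([g.reshape(-1) for g in grads])

g0 = param_grads(0)   # gradients under the original label
g1 = param_grads(2)   # gradients after changing the label

# A finite-difference style measure of how much the gradients react to the label.
diff_norm = (g1 - g0).norm()
print("||g1 - g0|| =", diff_norm.item())
```

This gives a discrete comparison between labels rather than a derivative, which may or may not be what the paper intends.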

Thanks for replying.

The gradients I’m trying to calculate after getting the loss from the model prediction and ground truth label are:

  1. The gradients of the loss w.r.t. all model parameters as g0
  2. The gradients of g0 w.r.t. the label as g1. Because the ground truth labels are not used in the forward pass, I now define g1 as the gradients of g0 w.r.t. the model prediction instead.

Therefore, g0 is obtained by regular backpropagation, and g1 shows how much each gradient element responds to the model output.

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Load CIFAR10 dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=True)

model = LeNet()
criterion = nn.CrossEntropyLoss()

# Get a batch of data
x, y = next(iter(trainloader))

output = model(x)
loss = criterion(output, y.long())

for param_layer in model.parameters():
    g0 = torch.autograd.grad(outputs=loss, inputs=param_layer, create_graph=True)[0]
    print(f"    g0 size: {g0.shape}")

    for param in parameters_to_vector(g0):
        # g1 = torch.autograd.grad(outputs=param, inputs=y, create_graph=True, allow_unused=True)[0] # replace y as output because y is not used in the model
        g1 = torch.autograd.grad(outputs=param, inputs=output, create_graph=True, allow_unused=True)[0]

Thank you for the reply.

y.grad would give me None values because it’s not used in the model except for calculating the loss.

Hi Rahn!

The short story is: pass your labels as a FloatTensor to
CrossEntropyLoss. Specifically, if output is a FloatTensor with
shape [nBatch, nClass], then y should also have that shape,
that is, have an explicit class dimension.

Your use of CrossEntropyLoss and the conversion of y to y.float()
suggests that y consists of integer categorical class labels. Because
these are discrete, they are inherently not differentiable.

Converting to float, setting requires_grad = True, and converting
back to long doesn’t fix this. (The result of y.long() no longer
carries requires_grad = True.)
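
A small sketch of that float / long round trip, using a made-up label value, shows where the connection to autograd is lost. The variable names are illustrative only.

```python
import torch

y_int = torch.tensor([3])                      # integer class label, as a DataLoader returns it
y_float = y_int.float().requires_grad_(True)   # float leaf tensor that requires grad
y_back = y_float.long()                        # cast back to integer for CrossEntropyLoss

print("y_float.requires_grad:", y_float.requires_grad)
print("y_back.requires_grad: ", y_back.requires_grad)
print("y_back.grad_fn:       ", y_back.grad_fn)

# Integer tensors cannot be marked as requiring gradients in the first place.
try:
    y_int.requires_grad_(True)
    raised = False
except RuntimeError:
    raised = True
print("requires_grad_ on a LongTensor raised RuntimeError:", raised)
```

Because the cast to an integer dtype is not differentiable, anything computed from `y_back` is disconnected from `y_float`.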

But yes, once you address your issue of discrete, non-differentiable
labels and how you use CrossEntropyLoss, your approach to
computing your desired second partial derivative is basically correct.

CrossEntropyLoss has two modes: In the first, you supply a single
integer class label per sample (so your labels would be a LongTensor
with no class dimension). In the second, you supply a set of class
probabilities (so your labels are a FloatTensor with an explicit class
dimension). Such class probabilities are sometimes called “soft” or
“probabilistic” labels.
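
As a rough illustration of the two modes for a batch, the sketch below feeds the same labels to CrossEntropyLoss once as integer indices and once as one-hot probabilities. The batch and class sizes are arbitrary assumptions, and probability targets need a reasonably recent PyTorch (1.10 or later, as far as I know).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_batch, n_class = 4, 10                             # illustrative sizes
logits = torch.randn(n_batch, n_class)               # stand-in for model output

y_index = torch.randint(0, n_class, (n_batch,))      # LongTensor, shape [n_batch]
y_prob = F.one_hot(y_index, n_class).float()         # FloatTensor, shape [n_batch, n_class]

criterion = torch.nn.CrossEntropyLoss()
loss_index = criterion(logits, y_index)              # mode 1: class indices
loss_prob = criterion(logits, y_prob)                # mode 2: class probabilities

print("shapes:", y_index.shape, y_prob.shape)
print("losses agree:", torch.allclose(loss_index, loss_prob))
```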

In the simplest approach you could:

y = torch.nn.functional.one_hot (y, nClass).float().requires_grad_ (True)

y will now have an explicit class dimension and you will be able to
compute gradients with respect to y.
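
A minimal sketch of this, assuming a small linear layer as a stand-in for the LeNet in the question: with probability targets and the default mean reduction, the loss is linear in y, so the gradient with respect to y should be the negative log-softmax of the output divided by the batch size. The sizes and label values are made up for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_batch, n_class = 2, 10
model = torch.nn.Linear(8, n_class)    # hypothetical stand-in for LeNet
x = torch.randn(n_batch, 8)
y_index = torch.tensor([3, 7])         # integer labels as loaded from the dataset

# One-hot float labels are a leaf tensor that can require gradients.
y = F.one_hot(y_index, n_class).float().requires_grad_(True)

output = model(x)
loss = torch.nn.CrossEntropyLoss()(output, y)
loss.backward()

# Compare against the closed form -log_softmax(output) / n_batch.
expected = -F.log_softmax(output, dim=1).detach() / n_batch
print("y.grad.shape:", y.grad.shape)
print("matches closed form:", torch.allclose(y.grad, expected, atol=1e-6))
```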

There are some nuances (which may or may not matter for your use
case). y should be understood as a (discrete) probability distribution
over classes. Therefore each element should be between 0.0 and
1.0 and y should sum to 1.0. (The result of one_hot() satisfies
this.)

But when you think about varying an element of y so as to compute
a gradient with respect to y, it doesn’t necessarily make logical sense
to have an element of y be exactly equal to 0.0 or 1.0, because
if it were to vary below 0.0 or above 1.0, the resulting y would fail
(by an infinitesimal amount) to be a valid probability distribution.

In a similar vein, if you vary a single element of y in isolation, y will
no longer sum to 1.0 and again not be a valid probability distribution.
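
One way to see the effect of the sum-to-one constraint is to route the labels through y / y.sum() before the loss, which is an assumption about how to handle it rather than the only option. In the sketch below (arbitrary logits, made-up soft labels), the gradient through the normalization has no component along y itself, because rescaling y leaves the normalized labels unchanged, whereas the raw gradient does.

```python
import torch

torch.manual_seed(0)
n_class = 5
logits = torch.randn(n_class)          # stand-in for a model prediction
y = torch.tensor([0.0125, 0.0125, 0.95, 0.0125, 0.0125], requires_grad=True)

ce = torch.nn.CrossEntropyLoss()

# Raw gradient: every element of y is treated as free to move independently.
loss_raw = ce(logits, y)
grad_raw = torch.autograd.grad(loss_raw, y)[0]

# Gradient through the normalization: variations that only rescale y are removed.
loss_norm = ce(logits, y / y.sum())
grad_norm = torch.autograd.grad(loss_norm, y)[0]

dot_raw = (grad_raw * y).sum().item()      # equals the loss itself
dot_norm = (grad_norm * y).sum().item()    # zero up to rounding
print("grad_raw  . y =", dot_raw)
print("grad_norm . y =", dot_norm)
```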

Last, the “sensitivities” of your parameter gradients with respect to your
labels are just the second partial derivatives of your loss with respect
to your model parameters and your labels. (These form a subset of the
full “hessian” of your loss with respect to the parameters and labels.)

You’re allowed, if you so choose, to first compute the gradient of the
loss with respect to the labels, and then compute the jacobian of that
gradient with respect to the model parameters. (That is to say, the
hessian is symmetric.) It turns out to be a little more convenient to
compute the gradients in that order (but you can do it either way).
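
The symmetry can be checked numerically. The sketch below, which assumes a tiny linear model and uniform soft labels chosen only for illustration, computes the mixed second partials for the bias in both orders and compares them. Label-first needs one backward pass per class, while parameter-first needs one per parameter element, which is why label-first tends to be cheaper for a real network.

```python
import torch

torch.manual_seed(0)
n_class = 4
model = torch.nn.Linear(3, n_class)    # hypothetical small model
x = torch.randn(3)
y = torch.full((n_class,), 1.0 / n_class, requires_grad=True)
ce = torch.nn.CrossEntropyLoss()

# Order A: label gradient first, then differentiate each element w.r.t. the bias.
loss = ce(model(x), y)
g_y = torch.autograd.grad(loss, y, create_graph=True)[0]
A = torch.stack([torch.autograd.grad(g, model.bias, retain_graph=True)[0] for g in g_y])

# Order B: bias gradient first, then differentiate each element w.r.t. the labels.
loss = ce(model(x), y)
g_b = torch.autograd.grad(loss, model.bias, create_graph=True)[0]
B = torch.stack([torch.autograd.grad(g, y, retain_graph=True)[0] for g in g_b])

# A is indexed [label, bias] and B is indexed [bias, label], so compare A with B transposed.
print("same mixed partials:", torch.allclose(A, B.T, atol=1e-6))
```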

Here is a script that shows how to compute your desired sensitivities
(computing the label gradient first and accounting for the nuances
due to the “soft” labels being a probability distribution):

import torch
print (torch.__version__)

_ = torch.manual_seed (2024)

nClass = 5

model = torch.nn.Linear (3, nClass)
inp = torch.randn (3)

prd = model (inp)

targetIndex = torch.tensor (2)                                             # target as integer categorical index
targetOneHot = torch.nn.functional.one_hot (targetIndex, nClass).float()   # target as "soft" probabilities

lossIndex = torch.nn.CrossEntropyLoss() (prd, targetIndex)
lossOneHot = torch.nn.CrossEntropyLoss() (prd, targetOneHot)

print ('torch.equal (lossIndex, lossOneHot):', torch.equal (lossIndex, lossOneHot))

def softHot (prob, ind, nc):                                               # soft target not saturated at 0 and 1
    pLo = (1.0 - prob) / (nc - 1)
    soft = torch.full ((nc,), pLo)
    soft[ind] = prob
    return  soft

targetSoftHot = softHot (0.95, targetIndex, nClass)

lossSoftHot = torch.nn.CrossEntropyLoss() (prd, targetSoftHot)

print ('lossOneHot: ', lossOneHot)                                         # the two loss values are similar
print ('lossSoftHot:', lossSoftHot)

targetSoftHot.requires_grad = True
targetNorm = targetSoftHot / targetSoftHot.sum()                           # affects grad, so not a no-op

loss = torch.nn.CrossEntropyLoss() (prd, targetNorm)

grad = torch.autograd.grad (loss, targetSoftHot, create_graph = True)[0]
print ('grad.shape:', grad.shape)

hessW = []
for  g in grad:
    hessW.append (torch.autograd.grad (g, model.weight, retain_graph = True)[0])

hessW = torch.stack (hessW)
print ('hessW.shape:', hessW.shape)

hessB = []
for  g in grad:
    hessB.append (torch.autograd.grad (g, model.bias, retain_graph = True)[0])

hessB = torch.stack (hessB)
print ('hessB.shape:', hessB.shape)

print ('hessB = ...')                                                      # second partials of loss with respect to
print (hessB)                                                              # target and bias, 5 elements each

And here is its output:

2.3.1
torch.equal (lossIndex, lossOneHot): True
lossOneHot:  tensor(1.4802, grad_fn=<DivBackward1>)
lossSoftHot: tensor(1.4942, grad_fn=<DivBackward1>)
grad.shape: torch.Size([5])
hessW.shape: torch.Size([5, 5, 3])
hessB.shape: torch.Size([5, 5])
hessB = ...
tensor([[-0.9875,  0.0125,  0.9500,  0.0125,  0.0125],
        [ 0.0125, -0.9875,  0.9500,  0.0125,  0.0125],
        [ 0.0125,  0.0125, -0.0500,  0.0125,  0.0125],
        [ 0.0125,  0.0125,  0.9500, -0.9875,  0.0125],
        [ 0.0125,  0.0125,  0.9500,  0.0125, -0.9875]])
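
As a cross-check on the printed hessB, working through the derivative by hand for this normalized setup suggests that entry [j, k] is simply t_k minus 1 when j equals k, where t is the soft target, independent of the prediction. The sketch below re-runs the relevant part of the script and compares against that closed form. The closed form is my own derivation rather than something stated in the post, so treat it as an assumption to verify.

```python
import torch

torch.manual_seed(2024)
n_class = 5
model = torch.nn.Linear(3, n_class)
inp = torch.randn(3)

# Same soft target as softHot(0.95, 2, 5) in the script above.
t = torch.full((n_class,), 0.0125)
t[2] = 0.95
t.requires_grad = True

loss = torch.nn.CrossEntropyLoss()(model(inp), t / t.sum())
grad = torch.autograd.grad(loss, t, create_graph=True)[0]
hessB = torch.stack([torch.autograd.grad(g, model.bias, retain_graph=True)[0] for g in grad])

# Assumed closed form: hessB[j, k] = t[k] - (1 if j == k else 0).
closed_form = t.detach().unsqueeze(0) - torch.eye(n_class)
print("matches closed form:", torch.allclose(hessB, closed_form, atol=1e-5))
```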

Best.

K. Frank