Differentiate weight gradients with respect to labels

Hey all, I’m wondering if it’s possible to compute the derivative of updated model parameters with respect to the labels, more concretely:

dw_{t+1}/dy_t = d/dy_t(w_t - \eta*dL/dw_t) = -\eta*d/dy_t(dL/dw_t)

where w_{t+1} are the updated weights after computing the loss at time step t and taking a gradient step with learning rate \eta. This should be possible, since the loss at time t is a function of y_t.
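
To make the target concrete, here is a minimal sketch of the object I'm after, using plain SGD on a tiny throwaway linear model and brute-forcing the full Jacobian with torch.autograd.functional.jacobian (the model, the names and the shapes are just placeholders, and it assumes a PyTorch version where cross_entropy accepts class-probability targets):

import torch

torch.manual_seed(0)
w = torch.randn(3, 4, requires_grad=True)   # w_t: weights of a tiny linear model
x_small = torch.randn(5, 4)                 # fixed inputs
eta = 0.1                                   # learning rate \eta

def updated_weights(y_small):
    # one plain SGD step, written as a function of the (soft) labels y_t
    logits = x_small @ w.T
    loss = torch.nn.functional.cross_entropy(logits, y_small)
    (dL_dw,) = torch.autograd.grad(loss, w, create_graph=True)
    return w - eta * dL_dw                  # w_{t+1}

y_small = torch.nn.functional.one_hot(torch.randint(3, (5,)), num_classes=3).float()
jac = torch.autograd.functional.jacobian(updated_weights, y_small)
print(jac.shape)  # torch.Size([3, 4, 5, 3]): weight shape followed by label shape

This is the kind of tensor I'd like to get for the real model below.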

When I try to do this with autograd.grad(), the result has the same shape as my labels, whereas the full Jacobian should also incorporate the shape of the weights.

What am I doing wrong?

Here’s some example code:

import torch
from torch import nn, optim
from torch.autograd import grad

class Model(torch.nn.Module):
    def __init__(self):
        super(Model,self).__init__()
        self.layer1 = nn.Sequential(
            nn.Linear(32,16),
            nn.ReLU()
        )
        self.layer2 = nn.Sequential(
            nn.Linear(16,8),
            nn.ReLU()
        )
        self.out = nn.Linear(8,3)
        
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.out(x)
        return x
    
x = torch.randn(32, 32)
# one-hot labels as floats, with requires_grad so we can differentiate w.r.t. them
y = torch.nn.functional.one_hot(torch.randint(3, (32,)), num_classes=3)
y = y.to(torch.float32).requires_grad_(True)

model = Model()
loss = nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr = .001)

opt.zero_grad()
logits = model(x)
error = loss(logits, y)
# create_graph=True so the populated .grad tensors are themselves differentiable w.r.t. y
error.backward(retain_graph = True, create_graph = True)

# second-order grad: differentiate layer2's weight gradient w.r.t. the labels,
# contracted with a tensor of ones as grad_outputs
dw_dy = grad(model.layer2[0].weight.grad, y, torch.ones_like(model.layer2[0].weight.grad))[0]

# this passes: the result only has the shape of y, not (weight shape) x (label shape)
assert dw_dy.shape == y.shape
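
For reference, the shape I would expect from the full Jacobian is the weight shape followed by the label shape (written out only as a shape check, not as working code):

# what I'm hoping to get instead of a y-shaped tensor:
# dw_dy.shape == (*model.layer2[0].weight.shape, *y.shape)   # i.e. (8, 16, 32, 3)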