Compute expected transformation of gradient

Consider an additive loss (e.g., MSE) of the form F(y_hat, y) = sum f(y_hat[i], y[i]). We can think of this as an expectation over the empirical distribution induced by the training set. That is, up to a constant factor of 1/N for N samples, the above can also be written as the expectation E[ f(Y_hat, Y) ].

When we execute the following code

optim.zero_grad()
loss.backward()
optim.step()

the result is to step in the negative direction of grad F(y_hat, y) = sum grad f(y_hat[i], y[i]). Again, in expectation notation, this is E[ grad f(Y_hat, Y) ].

I would like to compute the expectation of a transformation T of the gradient, that is, E[ T(grad f(Y_hat, Y)) ]. How can I do this in PyTorch?

What do you mean by transformation T? If it’s an arbitrary function, surely you could just apply that to the individual losses pre-expectation, then backpropagate and you’d get what you want?

Also, your resultant expectation is only valid if the training data is independent of your model! (In most cases this is true, but something to be aware of in case it isn’t in the future)

Yes, I would like to apply T before the expectation is taken, but in PyTorch and other deep learning frameworks, gradients are aggregated over samples. So you can’t do what you suggest (at least not without some elbow grease). That’s effectively what I’m asking here: how/where do I apply the elbow grease to achieve this?

To convince you that this is non-standard, what I am describing has been implemented for TensorFlow here: GitHub - ProbabilisticNumerics/cabs: Tensorflow implementation of SGD with Coupled Adaptive Batch Size (CABS)

So, if loss is defined as the sum of some individual loss values then would you be able to apply T there before you define loss? So,

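import torch

# Placeholders so the snippet runs; swap in your own model, optimizer and T.
model = torch.nn.Linear(2, 1)  # R^2 -> R^1 function represented by ANN
optim = torch.optim.SGD(model.parameters(), lr=0.1)
T = torch.log1p  # stand-in for the transformation T

x = torch.randn(10, 2)  # batch size = 10
yhat = model(x)
y = torch.randn(10, 1)  # target values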

individual_losses = (y - yhat).pow(2)
individual_losses_transformed = T(individual_losses)
loss = individual_losses_transformed.mean()

#Then call optimizer...
optim.zero_grad()
loss.backward()
optim.step()

So, something along the lines of defining your own custom MSELoss function and letting autograd handle the rest?

Your code is computing the expected gradient of the transformation, not the expected transformation of the gradient.

In other words, your code computes

E[ grad T ( f ( Y_hat, Y ) ) ]

I want

E[ T ( grad f( Y_hat, Y ) ) ] 

(The order of operations differs in the above two)
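To make the difference concrete, here is a small numeric sketch in plain Python. It assumes a scalar model y_hat = w * x, a squared-error loss, and T(u) = u**2; both the toy data and this choice of T are made up for illustration and are not from the thread. It also shows that neither quantity equals T applied to the aggregated gradient, which is all a standard backward pass gives you access to.

```python
# Toy setup (all values hypothetical): y_hat = w * x, f = (y_hat - y)**2.
def f(w, x, y):
    return (w * x - y) ** 2

def grad_f(w, x, y):
    # d/dw (w*x - y)^2 = 2 * x * (w*x - y)
    return 2.0 * x * (w * x - y)

def T(u):
    # Illustrative transformation only.
    return u ** 2

def dT(u):
    # Derivative of T, needed for the chain rule below.
    return 2.0 * u

def mean(values):
    values = list(values)
    return sum(values) / len(values)

w = 1.0
data = [(1.0, 0.0), (2.0, 3.0)]  # (x, y) pairs

# E[ grad T(f) ]: chain rule T'(f) * grad f per sample, then average.
expected_grad_of_T = mean(dT(f(w, x, y)) * grad_f(w, x, y) for x, y in data)

# E[ T(grad f) ]: transform each per-sample gradient, then average.
expected_T_of_grad = mean(T(grad_f(w, x, y)) for x, y in data)

# T(E[ grad f ]): transform the already-aggregated gradient.
T_of_expected_grad = T(mean(grad_f(w, x, y) for x, y in data))

print(expected_grad_of_T, expected_T_of_grad, T_of_expected_grad)
# -2.0 10.0 1.0
```

The per-sample gradients here are 2 and -4, so the three quantities come out as -2, 10 and 1 respectively.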

Ah yes, those operations won’t commute. In that case, an efficient way of doing this would be to compute per-sample gradients.

There are 2 ways to compute per-sample gradients:

1. Via hooks (discussion here)
2. FuncTorch (Repo here)

Since you want to apply a transformation, FuncTorch will probably be the best way, as it allows for higher-order gradients of per-sample gradients. The hooks method only retrieves the gradients, so you won’t be able to differentiate them if your transformation requires it (you’d also need to define some manual derivatives, which can get messy).
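A minimal sketch of the FuncTorch route, assuming PyTorch 2.x, where the functorch transforms are available as torch.func. The linear model, the random data, and the elementwise-square T are placeholders for illustration, not anything specified in this thread.

```python
import torch
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)

# Hypothetical stand-in for the ANN (R^2 -> R^1).
model = torch.nn.Linear(2, 1)
params = {name: p.detach() for name, p in model.named_parameters()}

def sample_loss(params, x, y):
    # Loss f for ONE sample: add a batch dimension of size 1 for the model,
    # then remove it again before comparing with the target.
    y_hat = functional_call(model, params, (x.unsqueeze(0),))
    return (y_hat.squeeze(0) - y).pow(2).sum()

def T(g):
    # Illustrative transformation only: elementwise square.
    return g.pow(2)

x = torch.randn(10, 2)  # batch size = 10
y = torch.randn(10, 1)  # target values

# grad(...) differentiates w.r.t. the first argument (params).
# vmap maps over dim 0 of x and y while sharing params (in_dims=None).
per_sample_grads = vmap(grad(sample_loss), in_dims=(None, 0, 0))(params, x, y)

# E[ T(grad f) ]: transform each sample's gradient, then average over the batch.
expected_T = {name: T(g).mean(dim=0) for name, g in per_sample_grads.items()}

for name, value in expected_T.items():
    print(name, tuple(value.shape))
```

Each entry of per_sample_grads has a leading batch dimension of 10, and each entry of expected_T has the same shape as the corresponding parameter. One possible way to use the result, if that is what you are after, is to copy each entry into the matching parameter's .grad before calling optim.step().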

Thank you. I will check that out and get back to you.