Compute expected transformation of gradient

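Ah yes, those operations won’t commute: averaging the gradient over the batch and then transforming it is not the same as transforming each sample's gradient and then averaging. In that case, an efficient way of doing this would be to compute per-sample gradients.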
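
To make the non-commuting point concrete, here is a tiny plain-Python sketch. It assumes a hypothetical 1-D model `y_hat = w * x` with per-sample loss `(w*x - y)**2` and uses squaring, `f(g) = g**2`, as a stand-in transformation. The usual batch gradient only gives you the mean of the per-sample gradients, so it can only produce `f(mean(g))`, not `mean(f(g))`.

```python
# Hypothetical 1-D example: model y_hat = w * x, per-sample loss (w*x - y)**2.
xs = [1.0, 2.0, 3.0]
ys = [1.0, 1.0, 1.0]
w = 1.0


def sample_grad(w, x, y):
    # d/dw (w*x - y)**2 = 2 * x * (w*x - y)
    return 2.0 * x * (w * x - y)


def f(g):
    # Example (hypothetical) non-linear transformation of a gradient.
    return g ** 2


grads = [sample_grad(w, x, y) for x, y in zip(xs, ys)]  # [0.0, 4.0, 12.0]

# Transforming the usual batch-averaged gradient: f(E[g]) = (16/3)**2, about 28.44
transform_of_mean = f(sum(grads) / len(grads))
# Expected transformation, which needs per-sample gradients: E[f(g)] = 160/3, about 53.33
mean_of_transform = sum(f(g) for g in grads) / len(grads)
print(transform_of_mean, mean_of_transform)
```

The two numbers differ, which is why the per-sample gradients have to be materialized before the transformation is applied.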

There are two ways to compute per-sample gradients:

1. Via hooks (discussion here)
2. Via FuncTorch (repo here)

As you want to apply a transformation, FuncTorch will probably be the best option, since it allows higher-order gradients of per-sample gradients. The hooks method only gives you the gradients themselves: you won’t be able to differentiate them if your transformation requires it, and you’ll also need to define some manual derivatives, which can get messy.
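
A minimal sketch of the FuncTorch route, assuming PyTorch 2.x, where FuncTorch's transforms are available as `torch.func` (`vmap`, `grad`, `functional_call`). The toy `Linear` model, the random data, `sample_loss`, and the squared-norm `transform` are all hypothetical placeholders for your own model and transformation. The key pattern is `vmap(grad(loss), in_dims=(None, 0, 0))`: `grad` differentiates a single-sample loss with respect to the parameters, and `vmap` maps it over the batch dimension of the data but not over the parameters.

```python
import torch
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)

# Hypothetical toy model and batch, for illustration only.
model = torch.nn.Linear(3, 1)
x = torch.randn(8, 3)
y = torch.randn(8, 1)

# Parameters as a plain dict so torch.func can differentiate w.r.t. them.
params = {k: v.detach() for k, v in model.named_parameters()}


def sample_loss(p, xi, yi):
    # Run the model on a single sample (add a batch dim of size 1).
    pred = functional_call(model, p, (xi.unsqueeze(0),))
    return torch.nn.functional.mse_loss(pred, yi.unsqueeze(0))


def per_sample_grads(p):
    # grad w.r.t. p for one sample, vmapped over the batch dim of x and y.
    return vmap(grad(sample_loss), in_dims=(None, 0, 0))(p, x, y)


def transform(g):
    # Hypothetical transformation: squared L2 norm of each sample's gradient.
    return sum((t ** 2).flatten(start_dim=1).sum(dim=1) for t in g.values())


def expected_transformed(p):
    # E_i[f(g_i)]: transform each per-sample gradient, then average.
    return transform(per_sample_grads(p)).mean()


psg = per_sample_grads(params)
value = expected_transformed(params)
# Everything is built from torch.func transforms, so the result can be
# differentiated again (a gradient of a function of per-sample gradients).
outer = grad(expected_transformed)(params)
print(psg["weight"].shape)    # torch.Size([8, 1, 3])
print(outer["weight"].shape)  # torch.Size([1, 3])
```

Averaging `psg` over the batch dimension recovers the ordinary batch gradient, so this adds information rather than changing it; the final `grad(expected_transformed)` call is the higher-order step that the hooks approach would not give you for free.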