Some applications require per-sample gradients rather than a single mini-batch gradient. Let’s say we need them.
There are a few ways to do it with autograd:
- Call autograd many times, once for each sample’s loss in the mini-batch. This is slow.
- Use Goodfellow’s method (http://arxiv.org/abs/1510.01799), which basically multiplies the gradients with respect to the hidden states by the input to the layer to complete the gradient.
The second approach is efficient because it does only the necessary computation. However, we need to hand-code the derivative of each layer’s output with respect to its weights (autograd will not do that step for us).
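To make the hand-coded step concrete, here is a minimal sketch for a single linear layer. The squared-error loss, the `target` tensor, and the sizes are assumptions for illustration only; the loss is summed (not averaged) over the batch so that each sample's loss has weight 1.

```python
import torch

torch.manual_seed(0)
batch, in_features, out_features = 4, 3, 2

x = torch.randn(batch, in_features)
w = torch.randn(in_features, out_features, requires_grad=True)
target = torch.randn(batch, out_features)  # hypothetical regression target

# Ordinary forward pass with a shared weight.
y = torch.einsum('ni,ij->nj', x, w)
y.retain_grad()  # keep dL/dy, the gradient at the layer output

# Sum (not mean) over samples so each sample's loss has weight 1.
loss = ((y - target) ** 2).sum()
loss.backward()

# Hand-coded last step for a linear layer:
# dL_n/dw[i, j] = x[n, i] * dL/dy[n, j]
per_sample_grad = torch.einsum('ni,nj->nij', x, y.grad)
```

The outer product in the last line is the part autograd does not give us; every other layer type would need its own version of it.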
There is a way to keep autograd end-to-end, which is to design the layer “a bit differently”.
The problem with the current layer design, as far as per-sample gradients go, is that the weight is “shared” among the samples within a mini-batch. If it were not shared, we could compute the gradient with respect to each sample’s own copy of the weight. This is equivalent to Goodfellow’s method computation-wise.
Example for a linear layer:
Traditionally we define a linear layer:
    # x: (batch, in_features)
    # w: (in_features, out_features)
    y = torch.einsum('ni,ij->nj', x, w)
A revised version would be:
    # x: (batch, in_features)
    # w: (in_features, out_features)
    ww = w.expand(batch, in_features, out_features)  # one view of w per sample
    ww.retain_grad()  # ww is not a leaf, so ask autograd to keep its grad
    y = torch.einsum('ni,nij->nj', x, ww)
After the backward pass, ww.grad holds the per-sample gradient, with shape (batch, in_features, out_features).
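As a sanity check, the expanded-weight version can be compared against the slow one-backward-per-sample loop. This is only a sketch on a tiny example; the squared-error loss and `target` are assumptions for illustration, and the loss is summed over the batch (with a mean, each per-sample gradient would be scaled by 1/batch).

```python
import torch

torch.manual_seed(0)
batch, in_features, out_features = 4, 3, 2

x = torch.randn(batch, in_features)
w = torch.randn(in_features, out_features, requires_grad=True)
target = torch.randn(batch, out_features)  # hypothetical regression target

# One view of the weight per sample.
ww = w.expand(batch, in_features, out_features)
ww.retain_grad()  # ww is a non-leaf tensor, so its grad must be retained

y = torch.einsum('ni,nij->nj', x, ww)
loss = ((y - target) ** 2).sum()
loss.backward()
expanded_grads = ww.grad  # (batch, in_features, out_features)

# Reference: one autograd call per sample, using the shared weight.
loop_grads = []
for n in range(batch):
    y_n = torch.einsum('i,ij->j', x[n], w)
    loss_n = ((y_n - target[n]) ** 2).sum()
    (g,) = torch.autograd.grad(loss_n, w)
    loop_grads.append(g)
loop_grads = torch.stack(loop_grads)
```

On this example the two tensors should agree up to floating-point error, which can be checked with `torch.allclose(expanded_grads, loop_grads, atol=1e-5)`.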
Questions:
- Is the memory footprint about the same as for a mini-batch gradient, apart from the storage for the per-sample gradient of each weight?
- Computation-wise, would it be the same as the mini-batch gradient, since we just save ourselves the last sum-reduction step?
- Are there any implementation roadblocks that would make it much less efficient memory-wise or computation-wise?
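A small check that may help frame the first two questions. It only compares values and element counts on a toy example (the sizes and the squared-output loss are assumptions); it does not measure peak memory, and it says nothing about whether einsum materializes the expanded weight internally.

```python
import torch

torch.manual_seed(0)
batch, in_features, out_features = 8, 5, 3

x = torch.randn(batch, in_features)
w = torch.randn(in_features, out_features, requires_grad=True)

ww = w.expand(batch, in_features, out_features)
ww.retain_grad()
loss = torch.einsum('ni,nij->nj', x, ww).pow(2).sum()
loss.backward()

# Backward through expand sums over the batch dimension,
# so w.grad is still the ordinary mini-batch gradient.
matches_minibatch = torch.allclose(ww.grad.sum(dim=0), w.grad, atol=1e-5)

# Extra storage: the per-sample gradient has batch times as many elements as w.
extra_ratio = ww.grad.numel() // w.numel()
```

Here `extra_ratio` equals `batch`, so the per-sample gradient storage grows linearly with the batch size for every weight handled this way.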