Per-sample gradient, should we design each layer differently?

Some applications require per-sample gradients (not just the mini-batch gradient). Let's say we need them.

There are a few ways to do it with autograd:

  1. Call backward many times, once for each sample's loss in the mini-batch. This is slow.
  2. Use Goodfellow's method http://arxiv.org/abs/1510.01799, which basically multiplies the hidden state's gradients with the input to the layer to complete the gradient.

The idea of 2) is efficient because we do only the necessary computation. However, we need to manually code, for each layer, the derivative of its output wrt. its weights (autograd will not do that step for us).

There is a way to let autograd work end-to-end, which is to design the layer "a bit differently".

The problem with the current layer design and per-sample gradients is that the weight is "shared" among samples within a mini-batch. If it were not shared, we could compute the gradient wrt. each sample's copy of the weight. This is equivalent to Goodfellow's method computation-wise.

Example for a linear layer:

Traditionally we define a linear layer:

x: (batch, in_features)
w: (in_features, out_features)
y = torch.einsum('ni,ij->nj', x, w)

A revised version would be:

x: (batch, in_features)
w: (in_features, out_features)
ww = w.expand(batch, in_features, out_features)
ww.retain_grad()
y = torch.einsum('ni,nij->nj', x, ww)

After backward, we get ww.grad with shape (batch, in_features, out_features), i.e. the per-sample gradients.
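
For concreteness, here is a minimal runnable sketch of the expand trick (the shapes are arbitrary, chosen only for illustration):

import torch

batch, in_features, out_features = 4, 3, 5
x = torch.randn(batch, in_features)
w = torch.randn(in_features, out_features, requires_grad=True)

ww = w.expand(batch, in_features, out_features)  # a view, no copy of the weight
ww.retain_grad()                                 # keep the gradient of this non-leaf tensor
y = torch.einsum('ni,nij->nj', x, ww)

y.sum().backward()
print(ww.grad.shape)                             # (batch, in_features, out_features)
print(torch.allclose(ww.grad.sum(0), w.grad))    # True: summing over the batch recovers w.grad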

Questions

  • The memory footprint is about the same as for the mini-batch gradient, except for the storage of the per-sample gradient of each weight, is it not?
  • Computation-wise it should be the same as the mini-batch gradient; we just save ourselves the last sum-reduction step?
  • Are there any implementation roadblocks that would make it much less efficient memory-wise or computation-wise?

Hi,

The main roadblock to introducing this in the core library is that even though this works for some layers, it cannot be easily generalized to all operations supported by autograd.

If you only use those supported operations though, the autograd-hacks package by @Yaroslav_Bulatov is what I would use.

Per-example and mean-gradient calculations work on the same set of inputs, so PyTorch autograd already gets you 90% of the way there. For an illustration: consider the problem of backprop in a simple feedforward architecture.

Each layer has a Jacobian matrix and we get the derivative by multiplying them out.

[image: the end-to-end derivative written as a product of the per-layer Jacobian matrices]

While any order of matrix multiplication is valid, reverse differentiation specializes to multiplying them “left to right”.

For the "left-to-right" order, each op implements a "vector-Jacobian" product, called "backward" in PyTorch, "grad" in TensorFlow and "Lop" in Theano. This is done without forming the Jacobian matrix explicitly, which is necessary for large-scale applications.
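
For a quick illustration, torch.autograd.grad exposes exactly this: you pass the vector via grad_outputs and the Jacobian is never formed. A toy example with an elementwise op:

import torch

x = torch.randn(3, requires_grad=True)
y = torch.exp(x)                               # elementwise op, so J = diag(exp(x))
v = torch.randn(3)                             # the "backprop" vector coming from downstream
vjp, = torch.autograd.grad(y, x, grad_outputs=v)
print(torch.allclose(vjp, v * torch.exp(x)))   # True: v times diag(exp(x)), J never materialized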

When dealing with non-linear functions, each example in a batch corresponds to a different Jacobian, so our backward functions do this in a batched fashion.

The autograd engine traverses the graph and feeds a batch of these vectors (backprops) into each op to obtain the backprops for the downstream op.

To see how the per-example calculation works, note that for matmul, the parameter gradient is equivalent to computing a batch of outer products of matching activation/backprop vector pairs and then summing over the batch dimension. Activations are the values fed into the layer during the forward pass.

We get the per-example gradients by dropping the sum over the batch dimension:

Basically it's a matter of replacing
grad = torch.einsum('ni,nj->ij', backprops, activations)
with
grad1 = torch.einsum('ni,nj->nij', backprops, activations)

Because the grad1 calculation doesn't affect the downstream ops, you only need to implement this for layers that have parameters. The autograd-hacks lib does this for Conv2d and Linear. To extend it to a new layer, you would look for the "leaf" ops, look at their "backward" implementations and figure out how to drop the "sum over batch" part.
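
For concreteness, a rough sketch of that idea for a single Linear layer using hooks; the helper and attribute names (capture_per_example_grads, _activations, grad1) are made up here for illustration and are not the actual autograd-hacks API:

import torch
import torch.nn as nn

def capture_per_example_grads(layer: nn.Linear):
    def save_activations(module, inputs, output):
        module._activations = inputs[0].detach()           # (batch, in_features)

    def compute_grad1(module, grad_input, grad_output):
        backprops = grad_output[0].detach()                # (batch, out_features)
        # Same einsum as above: one outer product per example, no sum over the batch.
        module.weight.grad1 = torch.einsum('ni,nj->nij', backprops, module._activations)

    layer.register_forward_hook(save_activations)
    layer.register_full_backward_hook(compute_grad1)

layer = nn.Linear(3, 5)
capture_per_example_grads(layer)
layer(torch.randn(4, 3)).sum().backward()
print(layer.weight.grad1.shape)   # (4, 5, 3): one (5, 3) weight gradient per example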

I have been using your library, and I'm thankful for that. It is still the case that for some layers the derivative is not trivial to implement; I'm thinking of the likes of LSTM.

My proposed implementation should match yours (which is derived from Goodfellow's) for all the activation steps. Our methods only differ at the last step, where yours needs the derivative to be coded by hand while mine doesn't.

@albanD Could you elaborate a bit more on the roadblocks you mentioned? Could you also comment on the memory footprint and computation requirements?

I looked at your example more carefully and expand seems like a neat trick.

It's simpler, and it avoids capturing an extra reference to the activations like I do in autograd_hacks.
But the extra memory usage means you can't make this the default: big models are already memory-bottlenecked and gradients are large (e.g., a GPT-2 gradient is 6GB per example).


I don't remember, but I don't think the convolution ops accept an expanded weight, @Konpat_Ta_Preechakul?

As Yaroslav and I discussed before, the roadblock is that this accumulation is a property of the op you are doing, not of the autograd engine. You could be doing many operations on your weights before passing them to the Linear layer (for example, people who enforce some structure on the weights). In that case, it's not only the last layer that needs this special handling, but all the layers before it as well.

Don't be annoyed yet. I still can't see why convolution would not be compatible with an expanded weight. Sure, the current implementation might not support it, but the proposal very likely requires re-implementing all layers anyway.

The "per-example gradient" computation needs to distinguish between "data" and "parameter" inputs, so it's natural to handle it at the level of layers (i.e., subclasses of nn.Module) rather than at the level of ops.

A practical limitation is that PyTorch ops often don't support a batch dimension for the weight input. For example, https://arxiv.org/abs/1906.02506 needed to compute the loss over a batch of weights in PyTorch; the solution had to use 8 copies of the model for a batch size of 8 (one per GPU).


“per-example gradient” computation needs to distinguish between “data” and “parameter” inputs

Why is that the case?

A practical limitation is that PyTorch ops often don’t support batch dimension for the weight input.

This limitation doesn’t seem like a roadblock. It seems like something that a re-implementation could solve.

So that you know which Tensors to call “expand” on :slight_smile:
In other words, this doesn’t need to change the autograd engine, just the layer implementations.

Changing all layers to do this by default would break things because of the extra memory usage. But you could create custom layers that extend the existing ones with "per-example-gradient-capturing" behavior using the expand trick.

Adding weight batching for Conv2d seems like a generally useful feature to have. The "groups" feature already implements something similar, because each group has its own set of filters.
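
For example, here is a rough sketch (not how Conv2d is implemented today, just an illustration with arbitrary shapes) that combines the expand trick above with groups to get per-sample conv gradients:

import torch
import torch.nn.functional as F

batch, in_ch, out_ch, k, H, W = 4, 3, 8, 3, 32, 32
x = torch.randn(batch, in_ch, H, W)
w = torch.randn(out_ch, in_ch, k, k, requires_grad=True)   # shared conv weight

ww = w.expand(batch, out_ch, in_ch, k, k)                  # per-sample view of the weight
ww.retain_grad()

# Fold the batch into the channel dimension and give each sample its own group.
y = F.conv2d(x.reshape(1, batch * in_ch, H, W),
             ww.reshape(batch * out_ch, in_ch, k, k),
             groups=batch, padding=k // 2)
y = y.reshape(batch, out_ch, H, W)

y.sum().backward()
print(ww.grad.shape)                            # (batch, out_ch, in_ch, k, k)
print(torch.allclose(ww.grad.sum(0), w.grad))   # True: summing recovers the usual gradient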


I agree with you. There is nothing theoretical preventing us from doing it.
I said it was a roadblock in the sense that it will be time-consuming to implement and is unlikely to be high-priority enough to be done now by the core team.

That being said, we would be very happy to accept a PR that adds support for weights that have an extra dimension for the batch size.
If you do have time to work on this, you can open an issue that explains how you plan to add it, and I will be more than happy to help come up with a good design and then review the implementation!

The tricky parts/questions that need answering are:
backend:

  • Fast conv for weights with an extra batch dimension: like mm/bmm, we would want conv/bconv.
  • Check other commonly used Modules (like batchnorm) to make sure the underlying implementation will support batched weights.

frontend:

  • Do you need the ability to switch between the two modes?
  • Do we want this built into the existing nn.Modules? (If so, how do we handle possibly changing batch sizes?)
  • Do we want element-wise versions of the existing modules as new Modules?

Based on your intuition, do you think it would be inherently slower? For example, these batched weights might prevent some kind of low-level optimization. I have no experience in low-level GPU programming. :sweat:

Well, you could play tricks: if you detect that the weights are indeed only expanded, you can use the original forward and thus not have any performance loss.
For both the case where the weights are actually different and for the backward, I have to admit I don't know. But there will be some performance drop, because I'm afraid you won't be able to use the cudnn-optimized kernels.

That being said, it is expected to be more expensive to compute element-wise gradients in most cases.


Replacing a single matmul with 2 einsums could be slower because of the difference in speed between GPU memory types: you have fast SRAM for intermediate results and slow GPU memory for main storage. For low-arithmetic-intensity operations on a V100, HBM2 can't keep up with the compute units and becomes the bottleneck. For instance, you can't do torch.einsum('ni,nij->nj', x, ww) at nearly the same FLOPS as a regular matmul or conv.
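
A crude way to see this for yourself (not a rigorous benchmark; it assumes a CUDA device and the shapes are arbitrary):

import time
import torch

def bench(fn, iters=50):
    fn()                          # warm-up
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.time() - t0) / iters

n, i, j = 4096, 1024, 1024
x = torch.randn(n, i, device='cuda')
w = torch.randn(i, j, device='cuda')
ww = w.expand(n, i, j)            # a view, so no extra weight memory

print(bench(lambda: x @ w))                               # regular matmul
print(bench(lambda: torch.einsum('ni,nij->nj', x, ww)))   # batched-weight einsum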


That is precisely what I saw!