Per-sample gradient, should we design each layer differently?

Some applications require per-sample gradients (rather than a single mini-batch gradient). Let’s say we need them.

There are a few ways to do it with autograd:

  1. Call backward many times, once per sample’s loss in the mini-batch. This is slow.
  2. Use Goodfellow’s method (http://arxiv.org/abs/1510.01799), which essentially multiplies each layer’s output gradients by the layer’s input to complete the per-sample weight gradient.

Option 2) is efficient because it does only the necessary computation; however, we need to manually code the derivative of each layer’s output wrt. its weights (autograd will not do that step for us).

There is a way to let autograd handle everything end-to-end, which is to design the layer “a bit differently”.

The problem with the current layer design and per-sample gradients is that the weight is “shared” among samples within a mini-batch. If it were not shared, we could compute the gradient wrt. each sample’s copy of the weight. Computation-wise this is equivalent to Goodfellow’s method.

Example for a linear layer:

Traditionally we define a linear layer:

# x: (batch, in_features)
# w: (in_features, out_features)
y = torch.einsum('ni,ij->nj', x, w)

A revised version would be:

# x: (batch, in_features)
# w: (in_features, out_features)
ww = w.expand(batch, in_features, out_features)  # a per-sample view of the shared weight
ww.retain_grad()                                 # keep the gradient of this non-leaf tensor
y = torch.einsum('ni,nij->nj', x, ww)

We now get ww.grad with shape (batch, in_features, out_features): the per-sample gradient.
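
Here is a minimal runnable sketch of this expand trick (the layer sizes and the squared loss are just illustrative assumptions), checked against a per-sample loop:

import torch

batch, in_features, out_features = 4, 3, 2
x = torch.randn(batch, in_features)
w = torch.randn(in_features, out_features, requires_grad=True)

# Expand the shared weight into a per-sample view and keep its gradient.
ww = w.expand(batch, in_features, out_features)
ww.retain_grad()
y = torch.einsum('ni,nij->nj', x, ww)
loss = (y ** 2).sum()           # any loss that is a sum of per-sample terms
loss.backward()
per_sample_grad = ww.grad       # (batch, in_features, out_features)

# Reference: one backward pass per sample.
for n in range(batch):
    w_ref = w.detach().clone().requires_grad_(True)
    ((x[n] @ w_ref) ** 2).sum().backward()
    assert torch.allclose(per_sample_grad[n], w_ref.grad, atol=1e-6)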

Questions

  • Memory footprint is about the same as for a mini-batch gradient, except for the storage of the per-sample gradients themselves, is it not?
  • Computation-wise it would be the same as a mini-batch gradient; we just skip the final sum-reduction step?
  • Are there any roadblocks in terms of implementation such that it would be much less efficient memory-wise or computation-wise?

Hi,

The main roadblock to introducing this in the core library is that even though this works for some layers, it cannot easily be generalized to all operations supported by autograd.

If you only use such operations, though, the autograd-hacks package by @Yaroslav_Bulatov is what I would use.

Per-example and mean-gradient calculations work on the same set of inputs, so PyTorch autograd already gets you 90% of the way there. For an illustration: consider the problem of backprop in a simple feedforward architecture.

Each layer has a Jacobian matrix and we get the derivative by multiplying them out.

[figure: the end-to-end derivative written as a product of the per-layer Jacobian matrices]

While any order of matrix multiplication is valid, reverse differentiation specializes to multiplying them “left to right”.

For the “left-to-right” order, each op implements a “vector-Jacobian” product, called “backward” in PyTorch, “grad” in TensorFlow and “Lop” in Theano. This is done without forming the Jacobian matrix explicitly, which is necessary for large-scale applications.
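
As a concrete illustration of a vector-Jacobian product in PyTorch (a tiny sketch of mine, not from this thread):

import torch

x = torch.randn(5, requires_grad=True)
y = torch.tanh(x)                 # an op with a (diagonal) Jacobian
v = torch.randn(5)                # the incoming "backprop" vector

# torch.autograd.grad computes v^T J without ever materializing the Jacobian J.
vjp, = torch.autograd.grad(y, x, grad_outputs=v)
assert torch.allclose(vjp, v * (1 - torch.tanh(x) ** 2))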

When dealing with non-linear functions, each example in a batch corresponds to a different Jacobian, so our backward functions do this in a batch fashion.

The autograd engine traverses the graph and feeds a batch of these vectors (backprops) into each op to obtain the backprops for the downstream op.

To see how the per-example calculation works, note that for matmul the parameter gradient is equivalent to computing a batch of outer products of matching activation/backprop vector pairs and then summing over the batch dimension. Activations are the values fed into the layer during the forward pass.

We get the per-example gradients by dropping that sum: grad1[n] is the outer product of the n-th backprop vector and the n-th activation vector.

Basically it’s a matter of replacing
grad=torch.einsum('ni,nj->ij', backprops, activations)
with
grad1=torch.einsum('ni,nj->nij', backprops, activations)

Because the grad1 calculation doesn’t affect any downstream computation, you only need to implement this for layers that have parameters. The autograd-hacks lib does this for Conv2d and Linear. To extend it to a new layer, you would look at the “leaf” ops, look at their “backward” implementations, and figure out how to drop the “sum over batch” part.
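
A rough, simplified sketch of this hook-based approach for nn.Linear (in the spirit of what autograd-hacks does, but the helper name here is made up for illustration):

import torch
import torch.nn as nn

def add_per_sample_hooks(layer, store):
    # Save the layer input (activations) on the forward pass and the
    # output gradients (backprops) on the backward pass.
    layer.register_forward_hook(lambda m, inp, out: store.update(activations=inp[0].detach()))
    layer.register_full_backward_hook(lambda m, gin, gout: store.update(backprops=gout[0].detach()))

model = nn.Linear(3, 2)
store = {}
add_per_sample_hooks(model, store)

x = torch.randn(4, 3)
model(x).pow(2).sum().backward()

# Per-example weight gradient: one outer product per sample, no sum over the batch.
grad1 = torch.einsum('ni,nj->nij', store['backprops'], store['activations'])
assert torch.allclose(grad1.sum(dim=0), model.weight.grad)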

I have been using your library, and I’m thankful for it. Still, for some layers the derivative is not trivial to implement; I’m thinking of the likes of LSTM.

My proposed implementation should match yours (which is derived from Goodfellow’s) for all the activation steps. Our methods only differ at the last step, where yours needs the derivative coded by hand while mine doesn’t.

@albanD Could you elaborate a bit more on the roadblocks you mentioned? Could you also comment on the memory footprint and computation requirements?

I looked at your example more carefully and expand seems like a neat trick.

It’s simpler and avoids capturing an extra reference to the activations like I do in autograd_hacks.
But the extra memory usage means you can’t make this the default – big models are already memory-bottlenecked and gradients are large, e.g. a GPT-2 gradient is 6 GB per example.

I don’t remember exactly, but I don’t think the convolution ops accept an expanded weight, @Konpat_Ta_Preechakul?

As Yaroslav and I discussed before, the roadblock is that this accumulation is a property of the op you are doing, not of the autograd engine. You could be doing many operations on your weights before passing them to the Linear layer (for example, people who enforce some structure on their weights). In that case, you need this special handling not only for the last layer but for all the ops before it as well.

Please bear with me a little longer. I still can’t see why convolution would not be compatible with an expanded weight. The current implementation might well not support it, but the proposal very likely requires re-implementing all layers anyway.

The “per-example gradient” computation needs to distinguish between “data” and “parameter” inputs, so it’s natural to handle it at the level of layers (i.e., subclasses of nn.Module) rather than at the level of ops.

A practical limitation is that PyTorch ops often don’t support a batch dimension for the weight input. For example, https://arxiv.org/abs/1906.02506 needed to compute a loss over a batch of weights in PyTorch; the solution had to use 8 copies of the model for a batch size of 8 (one per GPU).

“per-example gradient” computation needs to distinguish between “data” and “parameter” inputs

Why is that the case?

A practical limitation is that PyTorch ops often don’t support batch dimension for the weight input.

This limitation doesn’t seem like a roadblock. It seems like something that a re-implementation could solve.

So that you know which Tensors to call “expand” on :slight_smile:
In other words, this doesn’t need to change the autograd engine, just the layer implementations.

Changing all layers to do this by default would break things because of the extra memory usage. But you could create custom layers that extend the existing layers with “per-example-gradient-capturing” behavior using the expand trick.
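
For illustration only, such a custom layer could look roughly like this (PerSampleLinear is a made-up name, and this is a sketch rather than a drop-in replacement):

import torch
import torch.nn as nn

class PerSampleLinear(nn.Module):
    # Linear layer exposing per-sample weight gradients via the expand trick.
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(0.01 * torch.randn(in_features, out_features))

    def forward(self, x):
        # One (non-materialized) view of the weight per sample in the batch.
        ww = self.weight.expand(x.shape[0], *self.weight.shape)
        ww.retain_grad()
        self.expanded_weight = ww   # keep a handle so ww.grad can be read after backward
        return torch.einsum('ni,nij->nj', x, ww)

layer = PerSampleLinear(3, 2)
x = torch.randn(4, 3)
layer(x).pow(2).sum().backward()
per_sample_grad = layer.expanded_weight.grad   # shape (4, 3, 2)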

Adding weight batching for Conv2d seems like a generally useful feature to have. The “groups” feature already implements something similar, because each group has its own set of filters.
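
As a rough sketch of that idea (my own illustration, assuming every sample gets its own filter bank of the same shape), the groups trick for a batch of conv weights would look something like:

import torch
import torch.nn.functional as F

B, C_in, C_out, H, W, k = 4, 3, 8, 16, 16, 3
x = torch.randn(B, C_in, H, W)
w = torch.randn(B, C_out, C_in, k, k)      # one filter bank per sample

# Fold the batch into the channel dimension and let groups keep samples separate.
y = F.conv2d(x.reshape(1, B * C_in, H, W),
             w.reshape(B * C_out, C_in, k, k),
             groups=B, padding=1)
y = y.reshape(B, C_out, H, W)

# Same result as applying each sample's filters to its own input.
y_ref = torch.stack([F.conv2d(x[i:i + 1], w[i], padding=1)[0] for i in range(B)])
assert torch.allclose(y, y_ref, atol=1e-5)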

I agree with you; there is nothing theoretical preventing us from doing it.
I said it was a roadblock in the sense that it will be time-consuming to implement and is unlikely to be high-priority enough for the core team to do it right now.

That being said, we will be very happy to accept a PR that adds support for weights that have an extra dimension for the batch size.
If you do have time to work on this, you can open an issue explaining how you plan to add it, and I will be more than happy to help work out a good design and then review the implementation!

The tricky parts/questions that need answering are:
backend:

  • Fast conv for weights with an extra batch dimension: just as we have mm/bmm, we would need conv/bconv.
  • Check other commonly used Modules (like batchnorm) to make sure the underlying implementations can support batched weights.

frontend:

  • Do you need the ability to switch between the two modes?
  • Do we want this to be built in existing nn.Modules? (if so, how do we handle possibly changing batch-sizes)
  • Do we want element-wise versions of the existing modules as new Modules?

Based on your intuition, do you think it would be inherently slower? For example, batched weights might prevent some kind of low-level optimization. I have no experience in low-level GPU programming. :sweat:

Well, you could play tricks: if you detect that the weights are merely expanded, you can use the original forward and thus avoid any performance loss.
For the case where the weights are actually different, and for the backward pass, I have to admit I don’t know. But I’m afraid there will be some performance drop, because you won’t be able to use the cuDNN-optimized kernels.

That being said, it is expected to be more expensive to compute element-wise gradients in most cases.

Replacing a single matmul with 2 einsums could be slower because of the difference in speed between GPU memory types. You have fast SRAM for intermediate results and slow GPU memory for main storage. For low-arithmetic-intensity operations on a V100, HBM2 can’t keep up with the compute units and becomes the bottleneck. For instance, you can’t run torch.einsum('ni,nij->nj', x, ww) at nearly the same FLOPS as a regular matmul or conv.
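
A quick way to see this for yourself (an illustrative timing sketch only, assuming a CUDA device; the sizes are arbitrary):

import time
import torch

def bench(fn, iters=50):
    fn()                          # warm-up
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters

n, d = 256, 512
x = torch.randn(n, d, device='cuda')
w = torch.randn(d, d, device='cuda')
ww = torch.randn(n, d, d, device='cuda')   # stands in for a per-sample weight

print('matmul         ', bench(lambda: x @ w))
print('batched einsum ', bench(lambda: torch.einsum('ni,nij->nj', x, ww)))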

That is precisely what I saw!

Hi everyone, and sorry for opening a topic which is over 2 years old.

I was wondering if I could ask a quick question about what precisely the per-sample gradient is here?

What I mean is that you recover the gradient for every sample by taking the grad_output and the activations of a given layer and combining the two with torch.einsum.

But I’m not 100% sure what exactly the per-sample gradient is. To illustrate my point, let’s assume I have inputs of shape [B, A], where B is the batch size and A is the number of input features (for example, with a simple feed-forward network). I pass these inputs through a loss function (containing a network of parameters) which returns a tensor of shape [B,], where each element is effectively the individual loss for one sample in the batch.

Suppose I define the (total) loss as the mean of these individual loss values (reducing the tensor of shape [B,] to a scalar) and backprop it via loss.backward(), while using the hooks (as mentioned above) to get the per-sample gradients of the parameters.

Would these gradients be either,

  1. the gradient of the (total) loss w.r.t the parameters for the i-th sample within the batch for all samples

  2. the gradient of the individual loss of the i-th sample w.r.t the parameters for all samples

The reason I’m asking for this clarification is that I’m after the 2nd case, where I can calculate the gradient of each individual loss value w.r.t. the parameters of an nn.Module, for all samples.

An example use case would be the KFAC optimizer, where you need to compute the exact Fisher over all samples when rescaling your preconditioned gradients.

I’ve managed to write example code comparing the two methods:

  1. using hooks to get the per-sample gradients of the (total) loss for all samples (in a similar way to autograd-hacks)

  2. iterating over the batch, computing the gradient of each individual loss w.r.t. the parameters sequentially, and storing the results in a list, which is then stacked to the same shape as the tensor returned by the hooks method (this is incredibly slow).

These two methods return different values, which would indicate one of two things: either I’ve made a mistake somewhere in my code, or the gradients returned via the hooks are the per-sample gradients of the total loss w.r.t. the parameters rather than the gradients of the individual losses w.r.t. the parameters. I should note that all my samples are independent of each other when passed through the loss function!

Apologies for the long reply to a 2-year-old post, but since this isn’t really covered in the docs I’d just like to clarify some points!

Any help would be greatly appreciated!

Thank you for your time,
Kind regards!

TL;DR - Are the per-sample gradients calculated via this autograd-hacks method the gradients of the total loss with respect to the parameters, or the gradients of each individual loss with respect to the parameters? (By individual loss I mean the loss of a single sample before the reduction to a scalar, so a batch of B samples has B individual losses.)

Thanks once again! :slight_smile:

A batch of size k with k per-example gradients is numerically equivalent to iterating over the k examples with batch size 1 and saving the k gradients you get.

This mini-unit test in the docs of autograd-hacks should additionally clarify the relationship between per-example gradients and the “gradient”

# param.grad: gradient averaged over the batch
# param.grad1[i]: gradient with respect to example i

for param in model.parameters():
  assert(torch.allclose(param.grad1.mean(dim=0), param.grad))

Note that KFAC implementations typically inject isotropic random vectors at the output layer (i.e., call output.backward(random_vec)) and compute a factored approximation of the resulting covariance matrix of gradients. In expectation this equals the Gauss-Newton matrix (with the softmax loss replaced by a least-squares loss), rather than the Fisher matrix.

The problem with the Fisher matrix (and Gauss-Newton with softmax) is that once you start converging, those matrices go to zero, so preconditioning by such a matrix becomes a division by zero.
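
In code, that injection looks roughly like this (an illustrative sketch, not tied to any particular KFAC implementation; the model and sizes are made up):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(3, 16), nn.Tanh(), nn.Linear(16, 5))
x = torch.randn(8, 3)

output = model(x)                 # (batch, num_classes) pre-softmax outputs
v = torch.randn_like(output)      # isotropic random vector injected at the output
model.zero_grad()
output.backward(v)                # each param.grad now holds J^T v for that parameter
# Accumulating factored covariance statistics of these gradients over many draws
# gives, in expectation, the Gauss-Newton-style matrix described above.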

I have an example of computing KFAC factors in a slightly cleaned-up autograd-lib here:

Hi @Yaroslav_Bulatov!

Thank you for the very quick and detailed response!

I assume the reason they’re equivalent is that the gradient of mean(L_i) w.r.t. parameter_j equals the gradient of mean(L_i) w.r.t. L_i' times the gradient of L_i' w.r.t. parameter_j, and the first gradient term is just a constant, so they’re equivalent? (I assume this only works for independent samples?)

Thank you for the explanation between per-sample gradient and “gradient”, I think I fully understand it now!

The follow-up question I would have, if that’s OK: when I ran an example script for my problem, grad1.mean(dim=0) did equal param.grad, which would show my own method is correct, right?

However, if I iterate through the data sequentially, calculate each individual loss and its gradients via loss_i.backward(), and store the parameters’ .grad tensors in a list (calling model.zero_grad() between samples so gradients don’t accumulate), I don’t get the same values as those returned via the hooks. So I’m somewhat confused as to why they differ.

Are there any caveats to using hooks for this? To give a bit more detail, my loss function contains a Laplacian operator, i.e. it calculates the trace of the Hessian of my network’s output w.r.t. the inputs. Could this be an issue? Perhaps the grad_output term comes from the Laplacian rather than from the loss function itself, which would explain the difference between the sequential method and the batch method?

If it’s not too much could I post an example to illustrate my particular problem?

Thank you for your time!

They should be the same. To debug, just take an example small enough to calculate the gradients by hand and compare against what you are seeing.

For instance, take y = w·x with w = 1 and loss y²/2; then the gradient w.r.t. w is x·y, which is easy to check by hand. Here’s an example of computing the gradients for two different x’s and checking that they match the manually computed gradients – Google Colab
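
A minimal sketch of that check (with made-up numbers):

import torch

# Two examples, processed one at a time so the gradient can be checked by hand.
for x_val in (2.0, 3.0):
    w = torch.tensor(1.0, requires_grad=True)
    x = torch.tensor(x_val)
    y = w * x
    (y ** 2 / 2).backward()
    # d(loss)/dw = y * x, i.e. x**2 when w = 1
    assert torch.allclose(w.grad, x * y)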

You can use this to debug your Hessian trace as well; the trace is especially easy when everything is a scalar.

Hi @Yaroslav_Bulatov!

I’ve managed to create a “relatively” simple example that gets the per-sample gradients of my network’s output w.r.t. the network’s parameters for all samples, and they match the sequential per-sample method. So that part works.

I’m now trying to do the same for another loss function which contains the Laplacian. I calculate the Laplacian via the trick shown here, which supports batches, whereas torch.autograd.functional.hessian only supports one example per call.

Because I store the grad_output in a dictionary under the key ['e'], every call to torch.autograd.grad (which happens N+1 times for an N×N Hessian matrix) overwrites each layer’s grad_output with only the final component of the Laplacian rather than the entire Laplacian. Is this method (as shown here) compatible with getting per-sample gradients via hooks when the per-sample gradients depend on other derivative information as well – in the linked code, the Laplacian of the network’s output w.r.t. the input? (If that makes sense?)

Thank you for your help! :slight_smile:

P.S. Is it possible to share the Colab example publicly (i.e. without requesting access)?