Quickly get individual gradients (not sum of gradients) of all network outputs

Right now, I’m doing:

output_gradients = []

for output in net_outputs:
    tmp_grad = {}
    net.zero_grad()
    output.backward(retain_graph=True)
    for name, param in net.named_parameters():
        # clone, otherwise every stored entry aliases the same tensors,
        # which net.zero_grad() wipes in place on the next iteration
        tmp_grad[name] = param.grad.clone()
    output_gradients.append(tmp_grad)

Since I have to call backward on each output, the backward passes are not parallelized, and thus the code is pretty slow.

Is there a faster way? Thanks!

By default we only support accumulated gradients, so this is not easy to do.
If you don't have memory constraints, you can use the torch.autograd.grad interface to compute separate gradients in one shot. (They won't be in .grad, but will be explicitly returned, so some manual bookkeeping is needed.)

Thanks for the tips! We ended up rewriting the math to avoid doing this, though.

@smth Can you elaborate more on how to use the torch.autograd.grad interface to compute separate gradients in one shot? Thanks.

grads = torch.autograd.grad(loss, parameters, retain_graph=True)

This returns the gradients as a tuple that exactly matches the parameters provided, same size and same order, so you don't need the inner for loop over named_parameters() from the code example above.

Is it possible to run the backward passes in parallel? Thanks!

Dear @evcu, in your example, loss is a scalar and parameters is a vector. However, the question is about multiple losses, and computing their gradients without summation. For me, @smth's reference to torch.autograd.grad wasn't helpful either. Its docs start with: "Computes and returns the sum of gradients of outputs w.r.t. the inputs", and that seems to be exactly what it does. Under what configuration does grad return individual gradients (at the expense of memory)?

As an aside, this is a pretty useful feature for differential privacy.

Exactly! Any updates on this? Would love to have an option to get individual gradients.

Hi @Skinish,

Look at the functorch library, which is designed to calculate per-sample gradients efficiently. Their GitHub repo can be found here.
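
Alternatively, recent PyTorch versions can batch the backward pass over grad_outputs directly. A minimal sketch (assumes your version of torch.autograd.grad supports the is_grads_batched flag):

import torch

net = torch.nn.Linear(3, 3)
x = torch.randn(5, 3)
outputs = net(x).sum(dim=1)  # one scalar per sample, shape (5,)

params = tuple(net.parameters())
# one one-hot row per output; is_grads_batched vmaps the backward
# pass over the leading dim of grad_outputs
eye = torch.eye(outputs.numel())
per_output_grads = torch.autograd.grad(outputs, params,
                                       grad_outputs=eye,
                                       is_grads_batched=True)
# per_output_grads[i] has shape (5, *params[i].shape)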

From what I see here, you need to perform the forward call in parallel to calculate the gradients. However, that consumes too much memory (RAM or GPU memory).
What if we already have a single (X, Y) and just want to calculate gradients for all the elements in Y?

If you’re hitting an OOM error, you can always ‘chunk’ the vmap operation: send multiple mini-batches and then concatenate the results.

In the example you shared, a ‘chunked’ version would look something like this:

import torch
from functorch import make_functional, vmap, grad

model = torch.nn.Linear(3, 3)
data = torch.randn(64, 3)
targets = torch.randn(64, 3)

func_model, params = make_functional(model)

def compute_loss(params, data, targets):
    preds = func_model(params, data)
    return torch.mean((preds - targets) ** 2)

# full-batch per-sample gradients, kept as a reference
out_full = vmap(grad(compute_loss), in_dims=(None, 0, 0))(params, data, targets)

nchunks = 8            # number of chunks
nparams = len(params)  # number of params, 2 in this case (weight & bias)

# perform vmap chunk by chunk
out = [vmap(grad(compute_loss), in_dims=(None, 0, 0))(params, data_chunk, targets_chunk)
       for data_chunk, targets_chunk in zip(data.chunk(nchunks), targets.chunk(nchunks))]
# flatten the list of per-chunk tuples into a single list of tensors
out = [item for sublist in out for item in sublist]
# regroup by parameter and concatenate the chunks along the batch dim
out_chunk = [torch.cat([out[nparams * chunk + p] for chunk in range(nchunks)], dim=0)
             for p in range(nparams)]

# check that both methods agree
for i in range(nparams):
    print("Param: %i Match? %s" % (i, torch.allclose(out_full[i], out_chunk[i])))
"""
Returns
Param: 0 Match? True
Param: 1 Match? True
"""

If you still have an issue, you can open an issue on functorch’s GitHub repo here. There is some work toward ‘chunking’ vmap itself, as seen in issue #680, but I believe it’s still under development!

Thanks for the reply 🙂
I opened an issue as you suggested.