Computing backward() in parallel

Hi,
I have a use case with multiple loss functions on which I have to call backward without applying any reduction like mean or sum. I want to compute the gradients for the different losses in parallel.

losses = [loss1, loss2, loss3]
losses.backward()   # pseudocode for what I would like to do

print(param.grad)
# It should contain the stacked gradients (the Jacobian), one row per loss

In particular, I am trying to implement this PCGrad functionality in PyTorch: https://github.com/tianheyu927/PCGrad/blob/c5fbd7c856526373828074f06875230f7f3ee79e/PCGrad_tf.py#L39

Is it possible to do this in parallel, without a for loop?
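For context, as far as I understand the linked TF code, it needs the per-task gradients so that each one can be projected away from the gradients it conflicts with before they are summed. Very roughly (the names below are mine, not from that repo):

    import random
    import torch

    def pcgrad_combine(task_grads):
        # task_grads: list of flat gradient vectors, one per loss
        projected = [g.clone() for g in task_grads]
        for g_i in projected:
            others = list(task_grads)
            random.shuffle(others)
            for g_j in others:
                dot = torch.dot(g_i, g_j)
                if dot < 0:  # conflicting direction: remove the component along g_j
                    g_i -= dot / g_j.norm() ** 2 * g_j
        return torch.stack(projected).sum(dim=0)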


Hi

What do you expect the grad to contain exactly? The sum of the gradients for the 3 losses?

I want grad to contain exactly the stack of gradients w.r.t. the three losses. This is how I am doing it at present:

    # One flat gradient vector per loss
    flattened_grads = [torch.empty(0, device=device) for _ in range(len(losses))]

    # Accumulate the gradient vectors, one backward pass per loss
    for i, loss in enumerate(losses):
        for optim in optimizers:
            optim.zero_grad()
        loss.backward(retain_graph=True)
        for net in nets:
            for name, param in net.named_parameters():
                flattened_grads[i] = torch.cat([flattened_grads[i], param.grad.view(-1)])

I want to optimize this code to run faster, ideally in parallel. Since the computation graphs are similar, I should ideally backprop through each node only once.
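(For reference, the same thing can also be written with torch.autograd.grad(), which returns the gradients directly and avoids the zero_grad() bookkeeping, though it is still one backward pass per loss:)

    # Same nets/losses as above; torch.autograd.grad() hands back the gradients
    # directly, so .grad does not need to be zeroed between losses.
    params = [p for net in nets for p in net.parameters()]

    flattened_grads = []
    for loss in losses:
        grads = torch.autograd.grad(loss, params, retain_graph=True)
        flattened_grads.append(torch.cat([g.reshape(-1) for g in grads]))

    stacked = torch.stack(flattened_grads)  # shape: (num_losses, num_params)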

If you want three different gradients computed, you will need to call backward 3 times I’m afraid :confused:

Since the losses depend on the same graph, each Node would need to be re-run for every loss. "Running in parallel" would therefore just run each Node once with a workload that is 3 times bigger, so it would not lead to any improvement.

There is a TensorFlow way of doing this, used here: https://github.com/tianheyu927/PCGrad/blob/c5fbd7c856526373828074f06875230f7f3ee79e/PCGrad_tf.py#L39. Isn’t something like this possible in PyTorch?

If not, are there any other ways to optimize the above code? My GPU still runs at around 50% usage, and increasing the batch size is detrimental to my results.

In general in PyTorch, multiple backward passes cannot run concurrently in different threads, so there is no benefit in adding such functionality :confused:
We are starting to work on this, but it is quite tricky as it has very deep implications for how the low-level engine behaves.

More generally, PyTorch is not great for small workloads at the moment, as the overhead of the framework is non-negligible in eager mode. But we are working on it :slight_smile:


BackPACK (https://github.com/f-dangel/backpack) and pytorch-Jacobian (https://github.com/ChenAo-Phys/pytorch-Jacobian) seem to be doing something like this, but they don’t support my use case directly. Would any method based on these be suitable here?

You can also use torch.autograd.functional.jacobian() if you use PyTorch 1.5 or later.
But it will still be done using a for-loop under the hood I’m afraid.
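For reference, a toy example of what jacobian() returns, assuming the unreduced losses can be written as a function of the parameters (the model here is just a placeholder):

    import torch

    x = torch.randn(8, 4)            # 8 samples, 4 features (dummy data)
    y = torch.randn(8)
    w = torch.randn(4, requires_grad=True)

    def unreduced_losses(w):
        # one squared-error loss per sample, no reduction
        return (x @ w - y) ** 2      # shape: (8,)

    jac = torch.autograd.functional.jacobian(unreduced_losses, w)
    print(jac.shape)                 # torch.Size([8, 4]): row i is the gradient of loss i w.r.t. w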

pytorch-Jacobian claims to do better than the for loops in https://github.com/f-dangel/backpack. I hope that PyTorch adds this feature soon.

I don’t remember BackPACK being able to compute multiple losses at the same time in the general case.
In special cases, they are able to compute in a single backward pass what a naive approach would need multiple backward passes for. But these are only special cases.


Hi! Greetings from more than one year later!

The functionality you talked about is exactly what I need now. Have you made any progress on it?

Thanks!

I’ve just found torch.vmap() in a newer version of PyTorch. I’ll try that first.

Any luck with this? It would be quite beneficial to access the gradients of the three losses separately; using the for loop just piles up the computation time of each backward pass.

Hey!
If the losses are fully independent, there isn’t much better that can be done than running each of them I’m afraid.
If they’re related, I would suggest using vmap or tools like backpack or element-wise gradients from core.
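For example, with torch.func (PyTorch 2.0+, previously functorch), something along these lines computes all per-loss gradients in one vectorized call; the model and data below are just placeholders:

    import torch
    from torch.func import functional_call, jacrev

    model = torch.nn.Linear(4, 1)
    x = torch.randn(3, 4)
    y = torch.randn(3, 1)
    params = dict(model.named_parameters())

    def unreduced_losses(params):
        preds = functional_call(model, params, (x,))
        return ((preds - y) ** 2).squeeze(-1)   # shape: (3,), one loss per sample

    # jacrev vectorizes the backward internally (vmap over vjp), so for each
    # parameter we get the 3 per-loss gradients stacked along dim 0,
    # e.g. grads['weight'].shape == (3, 1, 4).
    grads = jacrev(unreduced_losses)(params)
    print({name: g.shape for name, g in grads.items()})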

A work-around I found recently was to replicate some parts of the network and load their state from the main branch, then assign the combination of gradients (possibly just a summation, or whatever you need) to the main branch by hand. This way you trade off speed against memory. It works well, but it’s still a trade-off. It would be interesting to see such an implementation in PyTorch directly.
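Roughly, the idea looks like this (all names are illustrative, not a polished implementation):

    import torch

    def combined_backward(main_model, replicas, loss_fns, batch):
        # One replica per loss: sync weights from the main model, backprop each
        # loss through its own replica, then write the combined gradient back.
        per_loss_grads = []
        for replica, loss_fn in zip(replicas, loss_fns):
            replica.load_state_dict(main_model.state_dict())
            replica.zero_grad()
            loss_fn(replica, batch).backward()
            per_loss_grads.append([p.grad for p in replica.parameters()])

        # Combine the per-loss gradients (plain sum here; a PCGrad-style
        # projection could go in its place) and assign them by hand.
        for main_p, *replica_grads in zip(main_model.parameters(), *per_loss_grads):
            main_p.grad = torch.stack(replica_grads).sum(dim=0)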