Can the new functional autograd take batches? Also, is it more efficient to compute a Hessian with the new functional autograd than with the old autograd?

I have two questions. They’re pretty much self-explanatory once you read the title. Basically, I’m debating whether it’s worth using the new functional hessian over the old autograd method. The major benefits I’m looking for are runtime speedups and batch vectorization.

Hi,

  • It depends on what you mean by batch here. If it is a dimension like any other that your function handles, then yes. If it is a dimension that should be handled differently (like an element-wise Hessian?), then no. I would be interested, though, in a code sample of exactly what you would like here and what your use case is (we want to improve this to make it as feature-complete for users as possible).
  • Right now it won’t be more efficient, unfortunately. But I am actively working on benchmarks + new autograd features that will speed things up. If you use this new API, you will benefit from these future improvements when they come out. If you do it by hand, you will have to incorporate them into your code base yourself.

Hi, thanks for your answer! I’ll try to give an example of what I’m doing in pseudocode:

Let’s say I have 32 samples in a batch, which I run through a module net1 that maps them to size 10:

net1_out = net1(samples)  # shape: [32 x 10]

Now I want the hessian of another module net2 w.r.t. these:

hessian(net2, net1_out, create_graph=True)

What I mean by “in a batch” is that I would like the output to have shape: [32 x 10 x 10]

However, the output seems to be: [32 x 10 x 32 x 10]

In other words, it seems to be computing Hessian cross-terms between batch items, which is more expensive than what I need.
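(For what it’s worth, since net2 doesn’t mix samples, the [32 x 10 x 10] blocks I actually want are just the diagonal of that bigger tensor; everything else is wasted work. A small sketch with a placeholder scalar-per-sample net2, just to show the shapes:)

import torch
from torch.autograd.functional import hessian

net2 = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.Tanh(), torch.nn.Linear(10, 1))
f = lambda x: net2(x).sum()     # [32, 10] -> scalar; rows don't interact
net1_out = torch.randn(32, 10)

H = hessian(f, net1_out)                                  # [32, 10, 32, 10]
per_sample = H.diagonal(dim1=0, dim2=2).permute(2, 0, 1)  # [32, 10, 10]
print(H.shape, per_sample.shape)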


Yes, these functions work with functions in general and don’t know about specific constructs like “batch”.
You will have to define a function that takes inputs of size (10) and outputs of size (10), and run hessian for each element in the batch. (In the same way, all the other constructs in torch.autograd don’t know about batches.)
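A minimal sketch of that per-element loop (the names below are placeholders, not from this thread; note that torch.autograd.functional.hessian expects the function to return a single-element tensor, so net2 here is scalar-valued):

import torch
from torch.autograd.functional import hessian

net2 = torch.nn.Sequential(torch.nn.Linear(10, 1))  # placeholder scalar-valued module
net1_out = torch.randn(32, 10)                      # stand-in for the [32 x 10] output of net1

def scalar_fn(v):
    # v has shape [10]; hessian needs a single-element output
    return net2(v).squeeze()

# One (10 x 10) Hessian per batch element, stacked to [32 x 10 x 10]
per_sample_hessians = torch.stack(
    [hessian(scalar_fn, net1_out[i], create_graph=True) for i in range(net1_out.shape[0])]
)
print(per_sample_hessians.shape)  # torch.Size([32, 10, 10])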

So one way I could overcome this is to define the second-order derivative as its own PyTorch module, without using autograd, and then run my inputs through that in batches. That would be significantly faster because it takes advantage of the parallelism of batch processing instead of iterating over each element in a for loop. Of course, that’s kind of tedious to do, even for a standard MLP, especially if I want to use dropout. But it’s doable. Would it be possible to have batch vectorization in a future release? I think JAX’s autograd does it. In general, it would be very helpful to be able to treat derivatives of modules as modules themselves. Then not only could you process in batches, you could also reuse the same derivative modules without having to recompute them. And for functions where different inputs share the same derivative function, you could reuse one derivative module instead of recomputing a new one. Does that make sense?
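To make the “derivative as a module” idea concrete, here is a minimal hand-written sketch (not from this thread) for one simple case, f(x) = sum(tanh(W x + b)), whose Hessian w.r.t. x is W^T diag(tanh''(W x + b)) W and can be evaluated for a whole batch at once:

import torch
import torch.nn as nn

class TanhLayerHessian(nn.Module):
    """Hand-written Hessian of f(x) = sum(tanh(W x + b)) w.r.t. x, evaluated per batch row."""
    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):                # x: [batch, in_dim]
        z = self.linear(x)               # [batch, out_dim]
        t = torch.tanh(z)
        s2 = -2 * t * (1 - t ** 2)       # tanh''(z), elementwise
        W = self.linear.weight           # [out_dim, in_dim]
        # Batched W^T diag(s2) W -> [batch, in_dim, in_dim]
        return torch.einsum('oi,bo,oj->bij', W, s2, W)

lin = nn.Linear(10, 5)
hessian_module = TanhLayerHessian(lin)
print(hessian_module(torch.randn(32, 10)).shape)  # torch.Size([32, 10, 10])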

A tool like this would be really useful:
http://www.matrixcalculus.org/matrixCalculus

Check out the “export to Python” option. It spits out NumPy code, but if we had the same sort of tool for tensor calculus with PyTorch, that would make higher-order derivatives so much more efficient: a function goes in, and a function comes out. Unfortunately, the tool I linked only does matrix calculus, not tensor calculus.

I used that tool to define the Hessian of an MLP by hand, and it was orders of magnitude faster than the functional autograd hessian… I can share code if you like.

Specifically, functional.hessian took ~15.93 seconds, whereas my hand-derived version took ~0.21 seconds. Both returned the same results.

Wow. I wondered if the batch iteration was causing the major slowdown, but when I compared the functional hessian on just a single sample against my hand-derived hessian on a batch of 64 samples, my version still ran more than twice as fast despite doing 64x the work.

Hi,

Yes, writing the formula directly by hand will always be faster :slight_smile:
Also, for ops that support batching, batch size 1 vs. batch size 64 won’t make a big difference in runtime.

Batch vectorization is WIP, yes.


Thanks! I’m glad autograd is being expanded; the functional API is a good step, and batch vectorization would make it even better! My only confusion is that the pipeline I used to build my version of the hessian function was entirely automatic: I plugged my MLP function into the matrix calculus tool above, copied and pasted the NumPy code it produced, and simply changed the NumPy calls to PyTorch. Couldn’t the same pipeline be built into PyTorch to enable even faster derivatives? Sorry if this question is naive; I know an MLP is a very basic use case, and generalizing to more sophisticated architectures would probably require a lot more effort.

I think the main shortcoming of these pipelines is that they might fail for more complex architectures.
Lots of users don’t even know what the mathematical function corresponding to their model is, and writing it out would be too complex for a symbolic differentiator to handle.
In particular, some functions are very easy to write in code but very hard to write mathematically (ones with complex conditionals or for loops, for example).

But I think your point makes sense: for simple functions we might be able to have such a tool. That could be a third-party library.

That makes sense. I don’t know enough about the limitations of symbolic differentiators. Such a library could be built for common off-the-shelf models, but I agree that generalizing it to arbitrary architectures would be very complicated. Thanks for your help and for answering my questions! :slight_smile:


Hi @albanD, has there been any progress on computing a batched Hessian as Sam asked for? I also want to compute the Hessians of 32 separate 10-dimensional vectors, i.e. a [32 x 10] input with a [32 x 10 x 10] output (so 32 Hessians of the 10-dimensional vectors). Essentially, I would like it to work the way torch.matmul or torch.linalg.inv do, where they treat the last two dimensions as the matrix and loop over everything else.

I don’t see a way to do this in the documentation; is there a preferred method? I’m currently running a for loop over dimension 0 and calling hessian on each vector, but it is slow.

Thank you for any help!

The hessian function now has an experimental vectorize flag: torch.autograd.functional.hessian — PyTorch 1.9.0 documentation
But it doesn’t have full support for everything right now, and the performance might not be as high for some ops.
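For reference, a minimal sketch of the flag in use (the function and shapes here are made up for illustration); it only changes how the Hessian is computed, not its shape:

import torch
from torch.autograd.functional import hessian

def f(x):
    # scalar-valued function of a [32 x 10] input
    return (x ** 3).sum()

x = torch.randn(32, 10)

# vectorize=True batches the underlying gradient calls with vmap (experimental);
# the result still has the full [32, 10, 32, 10] shape.
H = hessian(f, x, vectorize=True)
print(H.shape)  # torch.Size([32, 10, 32, 10])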

1 Like

Thanks, but it seems to require an incredible amount of memory. My GPU has 8 GB of memory, and prior to calling hessian I have only allocated 121 KB according to torch.cuda.memory_allocated(0).

I am attempting to calculate a 1 x 1 Hessian for 320 x 20 batches, so my input is of the form [320 x 20 x 1].
When I call hessian with vectorize=False, it creates the [320 x 20 x 1 x 320 x 20 x 1] output as expected. If I call it with vectorize=True, I get an out-of-memory error once it uses more than 5 GB.

Is this expected behavior? I assume a 320 x 20 x 1 tensor is far smaller than most models. Am I correct to assume the output will be 320 x 20 x 1 x 1? The documentation doesn’t really explain what the vectorize flag’s expected behavior is, except that it uses vmap.

Thank you so much for your help. If it’s unclear why this requires so much memory, I’ll try to figure out other options.

EDIT: If I run a for loop over the 320 x 20 batches to calculate each 1 x 1 Hessian with create_graph=True, it takes about 2.8 GB.

If your input is of shape [320 x 20 x 1], then the Hessian is expected to be of size [320 x 20 x 1 x 320 x 20 x 1].
The vectorize flag doesn’t change what is computed, only the way it is computed.


There seems to be an easy way to get Hessians in a batched manner: all you need to do is ensure that your function does not mix samples across the batch.
Here is a code snippet:

import torch
from torch.autograd.functional import jacobian

p = 3   # dimension
n = 10  # number of samples in the batch
x = torch.randn(n, p, requires_grad=True)

Q_12 = torch.randn(p, p)
Q = Q_12.T @ Q_12  # random symmetric positive semi-definite matrix

# Quadratic function: 0.5 * x_i^T Q x_i summed over every row x_i of the input tensor
f = lambda x: 0.5 * torch.tensordot(x, x @ Q)

def get_sum_of_gradients(x):
    h = f(x)
    h.backward(create_graph=True)
    return x.grad.sum(0)  # shape [p]; summing over the batch is fine because samples don't interact

# Jacobian of the summed gradients has shape [p, n, p]; swap to [n, p, p] = one Hessian per sample
Hessian_per_sample = jacobian(get_sum_of_gradients, x).swapaxes(0, 1)

# Check that the Hessian of every sample is Q
print(torch.all(torch.isclose(Hessian_per_sample, torch.tile(Q, (n, 1, 1)))).item())

Great solution @parthe - thanks! After testing it for correctness and speed, I made these changes:

def get_sum_of_gradients(x):
    h = f(x)
    return torch.autograd.grad(h, x, create_graph=True)[0].sum(0)

Hessian_per_sample = jacobian(get_sum_of_gradients, x, vectorize=True).swapaxes(0, 1)

The alteration to get_sum_of_gradients was purely to avoid the memory-leak warnings that backward(create_graph=True) triggers.

The addition of vectorize=True led to significant speed-ups (~10x) for the cases I tested.

As an example test case, I computed the Hessian of f(x), where f is a scalar-valued 3-layer feed-forward neural network, with a batch size of 500 and input/hidden dims of 200, running on GPU. I got the following results (a rough sketch of the setup follows the list):

  • grad + Jacobian (vectorize=True): 0.02 seconds
  • grad + Jacobian (vectorize=False): 0.22 seconds
  • For-loop over batch + Hessian (vectorize=True): 7.24 seconds
  • For-loop over batch + Hessian (vectorize=False): 74.53 seconds
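Roughly, the grad + Jacobian variant looked like the sketch below (simplified; the layer sizes and activations here are just illustrative placeholders):

import time
import torch
from torch.autograd.functional import jacobian

d, batch = 200, 500
device = "cuda" if torch.cuda.is_available() else "cpu"

# Illustrative scalar-valued 3-layer feed-forward net; rows of x don't interact,
# so the summed-gradient trick from above still yields one Hessian per sample.
net = torch.nn.Sequential(
    torch.nn.Linear(d, d), torch.nn.Tanh(),
    torch.nn.Linear(d, d), torch.nn.Tanh(),
    torch.nn.Linear(d, 1),
).to(device)

f = lambda x: net(x).sum()
x = torch.randn(batch, d, device=device, requires_grad=True)

def get_sum_of_gradients(x):
    h = f(x)
    return torch.autograd.grad(h, x, create_graph=True)[0].sum(0)

start = time.time()
H = jacobian(get_sum_of_gradients, x, vectorize=True).swapaxes(0, 1)  # [batch, d, d]
print(H.shape, f"{time.time() - start:.2f} s")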