Efficient computation with multiple grad_outputs vectors in autograd.grad

I want to compute Jacobian matrices using PyTorch's autograd. Autograd natively computes vector-Jacobian products, so I'd simply like to pass an identity matrix (one basis vector per row) to obtain the full Jacobian (i.e., I^T J = J).

One wrinkle: I'd like to implement not only the standard reverse-mode AD computation of the Jacobian, but also a forward-mode version (which should be faster for most of my applications), using a trick due to Jamie Townsend (see the reference in the docstring below).
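To make the thread self-contained, here is my understanding of the trick in miniature. This is only a toy sketch (the names `f`, `x`, `u`, and `v` are mine, and it assumes `f` maps a 1-D tensor to a 1-D tensor), not the implementation I'm asking about below:

import torch

def jvp_via_double_vjp(f, x, v):
    # One reverse pass gives the VJP  u -> J^T u.  That map is linear in u,
    # so a second reverse pass through it (w.r.t. u) recovers the JVP  J v.
    # v has the shape of x; the result has the shape of y = f(x).
    x = x.detach().requires_grad_(True)
    y = f(x)
    u = torch.ones_like(y, requires_grad=True)  # dummy cotangent
    vjp = torch.autograd.grad(y, x, grad_outputs=u, create_graph=True)[0]
    return torch.autograd.grad(vjp, u, grad_outputs=v)[0]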

I’ve actually gotten the former (reverse-mode) working, using a few hints found online and some fairly straightforward work:

import torch
from torch import autograd


def rev_jacobian(fxn, x, n_outputs, retain_graph):
    """
    The basic idea is to create N copies of the input and then ask for
    each of the N dimensions of the output... this lets us compute J in
    a single call to PyTorch's vector-Jacobian engine.
    """

    # expand the input, one copy per output dimension
    n_outputs = int(n_outputs)
    repeat_arg = (n_outputs,) + (1,) * len(x.size())
    xr = x.repeat(*repeat_arg)
    xr.requires_grad_(True)

    # both y and I are shape (n_outputs, n_outputs)
    #  checking y's shape lets us report something meaningful
    y = fxn(xr).view(n_outputs, -1)

    if y.size(1) != n_outputs:
        raise ValueError('Function `fxn` does not give output '
                         'compatible with `n_outputs`=%d, size '
                         'of fxn(x) : %s' % (n_outputs, y.size(1)))
    I = torch.eye(n_outputs, device=xr.device)

    J = autograd.grad(y, xr,
                      grad_outputs=I,
                      retain_graph=retain_graph,
                      create_graph=True,  # for higher-order derivatives
                      )

    return J[0]
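For concreteness, a minimal usage sketch (the `nn.Linear` toy function here is my own stand-in for `fxn`, purely for illustration):

import torch
from torch import nn

net = nn.Linear(3, 5)   # toy map R^3 -> R^5; its Jacobian is just net.weight
x = torch.randn(3)
J = rev_jacobian(net, x, n_outputs=5, retain_graph=True)
print(J.shape)                         # torch.Size([5, 3])
print(torch.allclose(J, net.weight))   # True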

However, for the forward-mode version I can’t seem to get autograd to behave. Whenever I pass the identity matrix as `grad_outputs` in the second call, I get back the sum of the columns of the Jacobian rather than the Jacobian itself. Given the success of the reverse-mode implementation, it seems like it should be easy!

Here’s the current code; the TODO comment marks what I’d like to do instead:

def fwd_jacobian(fxn, x, n_outputs, retain_graph):
    """
    This implementation is very similar to the above, but with one
    twist. To implement forward-mode AD with reverse-mode calls, we
    first compute the reverse-mode VJP for one vector (v), then call
    d/dv(VJP) `n_inputs` times, once per input basis vector, to obtain
    the Jacobian.

    This should be faster if `n_outputs` > `n_inputs`.

    References
    ----------
    .[1] https://j-towns.github.io/2017/06/12/A-new-trick.html
         (Thanks to Jamie Townsend for this awesome trick!)
    """

    xd = x.detach().requires_grad_(True)
    n_inputs = int(xd.size(0))

    # first, compute *any* VJP -- what matters is that it is a
    # differentiable function of v
    v = torch.ones(n_outputs, device=x.device, requires_grad=True)
    y = fxn(xd.view(1, n_inputs)).view(n_outputs)

    if y.size(0) != n_outputs:
        raise ValueError('Function `fxn` does not give output '
                         'compatible with `n_outputs`=%d, size '
                         'of fxn(x) : %s' % (n_outputs, y.size(0)))

    vjp = autograd.grad(y, xd, grad_outputs=v,
                        create_graph=True,
                        retain_graph=retain_graph)[0]
    assert vjp.shape == (n_inputs,)

    # TODO somehow the repeat trick does not work anymore now that we
    #      have to take derivatives wrt v, so loop over basis vectors
    #      and compose the Jacobian column by column

    I = torch.eye(n_inputs, device=x.device)
    J = []
    for i in range(n_inputs):
        Ji = autograd.grad(vjp, v,
                           grad_outputs=I[i],
                           retain_graph=retain_graph,
                           create_graph=True,  # for higher-order derivatives
                           )
        J.append(Ji[0])

    return torch.stack(J).t()
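And a quick check that the forward-mode version matches the reverse-mode one on the same toy problem (again my own example, not a rigorous test):

import torch
from torch import nn

net = nn.Linear(3, 5)
x = torch.randn(3)
J_rev = rev_jacobian(net, x, n_outputs=5, retain_graph=True)
J_fwd = fwd_jacobian(net, x, n_outputs=5, retain_graph=True)
print(torch.allclose(J_rev, J_fwd))  # True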

Looping over each column of the final Jacobian is slow, of course, and wasteful – I’m computing the Jacobian N times, and throwing most of it away each time. Since this is the bottleneck in my code, it would be sweet to do it all in one go :grin:.

Hoping someone with more knowledge of the guts of autograd can advise on the proper use here. Thanks in advance!!

PS. For the sake of brevity I won’t post a test case just yet, but if anyone is interested I can post one to work from.

Hi,

My go-to implementations here are Adam’s gists: here is the one with the full Jacobian and Hessian, and here is the one for the forward-mode trick :slight_smile:

Thanks for those, but they don’t quite answer the question :). My question is: can I do this without looping over the columns of the Jacobian (which Adam does)?

I’m afraid you cannot. This is a limitation of automatic differentiation… You can only efficiently do vJ or Jv products.

Right. I suppose I was operating under the assumption that PyTorch’s implementation was flexible enough to accept an identity matrix as “v”. That may be wrong.

So if you compare these two implementations, the first gives significantly faster run times than the second in my hands, by a factor of 10-100x depending on the case. Perhaps this is just down to skipped CUDA overhead? I am trying to understand this behavior and see if I can exploit the same kind of speedup using the “forward mode trick”.

def _rev_jacobian(fxn, x, n_outputs, retain_graph=True):
    """
    the basic idea is to create N copies of the input
    and then ask for each of the N dimensions of the
    output... this allows us to compute J with pytorch's
    jacobian-vector engine
    """

    # expand the input, one copy per output dimension
    n_outputs = int(n_outputs)
    repeat_arg = (n_outputs,) + (1,) * len(x.size())
    xr = x.repeat(*repeat_arg)
    xr.requires_grad_(True)

    # both y and I are shape (n_outputs, n_outputs)
    #  checking y shape lets us report something meaningful
    y = fxn(xr).view(n_outputs, -1)

    if y.size(1) != n_outputs:
        raise ValueError('Function `fxn` does not give output '
                         'compatible with `n_outputs`=%d, size '
                         'of fxn(x) : %s'
                         '' % (n_outputs, y.size(1)))
    I = torch.eye(n_outputs, device=xr.device)

    J = autograd.grad(y, xr,
                      grad_outputs=I,
                      retain_graph=retain_graph,
                      create_graph=True,  # for higher order derivatives
                      )

    return J[0]


def _rev_jacobian_simple(fxn, x, n_outputs, retain_graph=True):

    n_outputs = int(n_outputs)

    xd = x.detach()
    xd.requires_grad_(True)
    n_inputs = int(xd.size(0))

    y = fxn(xd.view(1,n_inputs)).view(n_outputs)

    if y.size(0) != n_outputs:
        raise ValueError('Function `fxn` does not give output '
                         'compatible with `n_outputs`=%d, size '
                         'of fxn(x) : %s'
                         '' % (n_outputs, y.size(0)))
    I = torch.eye(n_outputs, device=xd.device)

    J = torch.zeros([n_outputs, n_inputs], device=xd.device)
    for i in range(n_outputs):
        J[i,:] = autograd.grad(y, xd,
                               grad_outputs=I[i],
                               retain_graph=retain_graph,
                               create_graph=True,  # for higher order derivatives
                               )[0]
    return J
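For anyone who wants to poke at this, here is roughly the kind of timing harness I’m using (the `nn.Linear` toy model and the sizes are placeholders, not my real model):

import time
import torch
from torch import nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = nn.Linear(128, 128).to(device)
x = torch.randn(128, device=device)

for impl in (_rev_jacobian, _rev_jacobian_simple):
    start = time.time()
    for _ in range(10):
        J = impl(net, x, n_outputs=128)
    if device == 'cuda':
        torch.cuda.synchronize()  # crude timing: flush GPU work before reading the clock
    print(impl.__name__, '%.4f s' % (time.time() - start))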

What if you use the same repeat trick for forward mode? I.e., replicate the x and v vectors and do the same as in reverse mode?

https://colab.research.google.com/drive/1tcm7Lvdv0krpPdaYHtWe7NA2bDbEQ0uj#scrollTo=-jwISHOycQ8x

That would be faster but consume much more memory. That might be OK in your case though!

You can do the same for the “forward mode trick” as well, I guess. But I’m not sure this will make the Jacobian computation faster than this one.

Tricks on tricks for the win?

@Yaroslav_Bulatov and @albanD thanks. In the end I copied Yaroslav – doing the “repeat” trick twice. I was hoping I would not have to re-compute N reverse passes, but it seems unavoidable.
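For posterity, here is the shape of what I ended up with. Treat it as a rough sketch rather than my exact code; it assumes `fxn` accepts a batch of inputs of shape `(batch, n_inputs)`:

def fwd_jacobian_repeat(fxn, x, n_outputs, retain_graph=True):
    # sketch only: forward-mode trick + the repeat trick, one input copy
    # (and one dummy cotangent row) per *input* dimension
    n_inputs = int(x.size(0))

    xr = x.detach().repeat(n_inputs, 1).requires_grad_(True)
    v = torch.ones(n_inputs, n_outputs, device=x.device, requires_grad=True)

    y = fxn(xr).view(n_inputs, n_outputs)

    # batched VJP: row i is J^T v_i, and it is linear in v
    vjp = autograd.grad(y, xr, grad_outputs=v, create_graph=True)[0]

    # differentiate the VJP against all input basis vectors in one call
    I = torch.eye(n_inputs, device=x.device)
    Jt = autograd.grad(vjp, v, grad_outputs=I,
                       retain_graph=retain_graph,
                       create_graph=True)[0]  # shape (n_inputs, n_outputs) == J^T

    return Jt.t()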

With code very similar to Yaroslav’s I get a ~5x speed boost using the repeat trick.

And consistent with @albanD’s expectations, the faster, non-looping implementations do use much more memory, so there is a tradeoff. For me speed is more important and I have memory to spare, so I will probably roll with that :).

Thanks to both of you guys!
