Right. I suppose I was operating under the assumption that PyTorch's implementation was flexible enough to accept an identity matrix as "v". That may be wrong.
So if you compare these two implementations, the first gives significantly faster run times (in my hands) than the second, by a factor of 10-100x depending on the case. Perhaps this is just due to some skipped CUDA overhead? I am trying to understand this behavior and see if I can exploit the same kind of speedup using the "forward mode trick".
import torch
from torch import autograd


def _rev_jacobian(fxn, x, n_outputs, retain_graph=True):
    """
    The basic idea is to create N copies of the input
    and then ask for each of the N dimensions of the
    output... this allows us to compute J with pytorch's
    jacobian-vector engine in a single backward pass.
    """

    # expand the input, one copy per output dimension
    n_outputs = int(n_outputs)
    repeat_arg = (n_outputs,) + (1,) * len(x.size())
    xr = x.repeat(*repeat_arg)
    xr.requires_grad_(True)

    # both y and I are shape (n_outputs, n_outputs)
    # checking y shape lets us report something meaningful
    y = fxn(xr).view(n_outputs, -1)
    if y.size(1) != n_outputs:
        raise ValueError('Function `fxn` does not give output '
                         'compatible with `n_outputs`=%d, size '
                         'of fxn(x) : %s' % (n_outputs, y.size(1)))

    I = torch.eye(n_outputs, device=xr.device)
    J = autograd.grad(y, xr,
                      grad_outputs=I,
                      retain_graph=retain_graph,
                      create_graph=True,  # for higher order derivatives
                      )

    return J[0]
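For reference, this is the kind of sanity check I run on the first version (my own toy example, not part of the original code: `lin` and `x0` are made-up names, and `torch.autograd.functional.jacobian` assumes a reasonably recent PyTorch). Note that whatever you pass as `fxn` has to accept the repeated, batched input, which nn.Linear does.

# quick sanity check of _rev_jacobian on a toy linear map (illustrative only)
import torch
from torch import nn

torch.manual_seed(0)
n_in, n_out = 5, 3
lin = nn.Linear(n_in, n_out)   # fxn must accept the repeated (batched) input
x0 = torch.randn(n_in)

J = _rev_jacobian(lin, x0, n_out)                    # shape (n_out, n_in)
J_ref = torch.autograd.functional.jacobian(lin, x0)  # reference Jacobian
print(torch.allclose(J, J_ref, atol=1e-6))           # expect: True
print(torch.allclose(J, lin.weight, atol=1e-6))      # for a linear layer, J == weight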
def _rev_jacobian_simple(fxn, x, n_outputs, retain_graph=True):
    """
    Reference implementation: build J one row at a time,
    with one backward pass per output dimension.
    """

    n_outputs = int(n_outputs)
    xd = x.detach()
    xd.requires_grad_(True)
    n_inputs = int(xd.size(0))

    y = fxn(xd.view(1, n_inputs)).view(n_outputs)
    if y.size(0) != n_outputs:
        raise ValueError('Function `fxn` does not give output '
                         'compatible with `n_outputs`=%d, size '
                         'of fxn(x) : %s' % (n_outputs, y.size(0)))

    I = torch.eye(n_outputs, device=xd.device)
    J = torch.zeros([n_outputs, n_inputs], device=xd.device)
    for i in range(n_outputs):
        # row i of J is the gradient of y[i] w.r.t. the input
        J[i, :] = autograd.grad(y, xd,
                                grad_outputs=I[i],
                                retain_graph=retain_graph,
                                create_graph=True,  # for higher order derivatives
                                )[0]

    return J
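And this is roughly how I am comparing them; a rough sketch only, with `net`, `n_in`, `n_out` chosen arbitrarily (on GPU you would also want torch.cuda.synchronize() around the timers to get honest numbers).

# rough timing sketch (arbitrary toy network; numbers will vary with size/device)
import time
import torch
from torch import nn

torch.manual_seed(0)
n_in = n_out = 128
net = nn.Sequential(nn.Linear(n_in, 256), nn.Tanh(), nn.Linear(256, n_out))
x0 = torch.randn(n_in)

results = {}
for name, jac_fn in [('batched', _rev_jacobian), ('looped', _rev_jacobian_simple)]:
    t0 = time.perf_counter()
    results[name] = jac_fn(net, x0, n_out)
    print('%s : %.4f s' % (name, time.perf_counter() - t0))

print(torch.allclose(results['batched'], results['looped'], atol=1e-5))  # expect: True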
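For completeness, here is what I mean by the "forward mode trick": get Jacobian-vector products out of reverse mode by differentiating a vjp with respect to its cotangent (the double-backward construction, which is my understanding of what torch.autograd.functional.jvp does internally). This builds J column by column, so it should only win when there are fewer inputs than outputs. Sketch only; `_fwd_jacobian` is a name I made up and I have not benchmarked it carefully.

def _fwd_jacobian(fxn, x, n_outputs):
    """
    Sketch of the "forward mode trick": one vjp with a dummy cotangent u
    (recorded with create_graph=True), then differentiate that vjp w.r.t. u
    to get Jacobian-vector products, i.e. columns of J.
    """
    xd = x.detach()
    xd.requires_grad_(True)
    n_inputs = int(xd.size(0))
    y = fxn(xd.view(1, n_inputs)).view(n_outputs)

    # dummy cotangent; its value is irrelevant, only the recorded graph matters
    u = torch.zeros(n_outputs, device=xd.device, requires_grad=True)
    vjp = autograd.grad(y, xd, grad_outputs=u, create_graph=True)[0]  # = J^T u

    I = torch.eye(n_inputs, device=xd.device)
    J = torch.zeros([n_outputs, n_inputs], device=xd.device)
    for j in range(n_inputs):
        # d(e_j . J^T u)/du = J e_j, i.e. the j-th column of J
        J[:, j] = autograd.grad(vjp, u, grad_outputs=I[j],
                                retain_graph=True)[0]
    return J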