I have a batch of inputs
import torch
x = torch.arange(10)
and a complicated routine that depends on some parameter C
C = torch.tensor(-5., requires_grad=True)
I use that routine to compute some outputs y. Rather than dump the whole complicated routine here, I'll make up something simple that exhibits the same issues:
y = C * x
Now I want to compute the derivative of y with respect to the parameter, dy/dC. In this case I know it's x, of course, but my 'real' routine is more involved and its derivative is nontrivial to compute or program by hand.
I know how to do it elementwise,
torch.tensor(tuple(torch.autograd.grad(yi, C, retain_graph=True)[0] for yi in y))
# tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]) # the correct answer!
But in my real problem this is pretty slow. Can I eliminate the python-level for loop?
I first tried
torch.autograd.grad(y, C)
torch.autograd.grad(y, C.expand(10))
# both give
# RuntimeError: grad can be implicitly created only for scalar outputs
and then
torch.autograd.grad(y, C, grad_outputs=torch.ones(10), allow_unused=True)
# (tensor(45.),)
which is close, since that's x.sum(); I think it essentially computed torch.dot(torch.ones(10), dy/dC), where dy/dC == x.
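A quick sanity check of that interpretation on the toy example (restating the setup so it runs on its own):

```python
import torch

x = torch.arange(10)
C = torch.tensor(-5., requires_grad=True)
y = C * x

# grad with ones as grad_outputs is a vector-Jacobian product
g, = torch.autograd.grad(y, C, grad_outputs=torch.ones(10))
print(g, torch.dot(torch.ones(10), x.float()))  # tensor(45.) tensor(45.)
```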
Well, that's almost what I want! Since torch.dot(torch.ones(10), ...) == torch.matmul(torch.ones(10), ...), can I replace torch.ones with torch.eye and get the full dy/dC as desired?
torch.autograd.grad(y, C, grad_outputs=torch.eye(10))
# RuntimeError: Mismatch in shape: grad_output[0] has a shape of torch.Size([10, 10]) and output[0] has a shape of torch.Size([10]).
No, that doesn’t work. Perhaps if I expand C?
torch.autograd.grad(y, C.expand(10), grad_outputs=torch.ones(10))
# RuntimeError: One of the differentiated Tensors appears to not have been used in the graph.
Hm, I suppose the computational graph can't "see through" the fact that every entry of C.expand(10) is C; in fact, this particular expanded tensor was created after y, so it never appeared in y's graph at all. That makes sense.
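Interestingly, if I expand C *before* running the routine, so that the expanded tensor actually appears in the graph, the per-element trick does work on the toy example (whether my real routine can be fed an expanded parameter is another question):

```python
import torch

x = torch.arange(10)
C = torch.tensor(-5., requires_grad=True)

Ce = C.expand(10)   # expand before the routine, so Ce is in the graph
y = Ce * x          # each y[i] now depends on its own entry Ce[i]
g, = torch.autograd.grad(y, Ce, grad_outputs=torch.ones(10))
print(g)  # tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
```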
Maybe I need to differentiate elementwise and use vmap?
import functorch
def ddC(yi):
    d = torch.autograd.grad(yi, C, is_grads_batched=True)
    # d = torch.autograd.grad(yi, C.expand(10), is_grads_batched=True)  # gives the same error
    return d[0]
functorch.vmap(ddC)(y)
# ValueError: vmap(vjp, in_dims=0, ...)(<inputs>): Got in_dim=0 for an input
# but the input is of type <class 'NoneType'>. We cannot vmap over non-Tensor
# arguments, please use None as the respective in_dim
I don’t understand this error, and I don’t understand where NoneType
could be coming from.
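Re-reading the docs, I now wonder whether is_grads_batched is meant to be used without vmap, passing the whole batch of grad_outputs at once. On the toy example this does seem to produce the elementwise derivative, but am I using it as intended?

```python
import torch

x = torch.arange(10)
C = torch.tensor(-5., requires_grad=True)
y = C * x

# one row of grad_outputs per output element, batched over the first dim
g, = torch.autograd.grad(y, C, grad_outputs=torch.eye(10), is_grads_batched=True)
print(g)  # tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
```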
Any advice? How can I get the same result as the element-wise differentiation without a slow python for loop? Or is that what I should stick to?
Should I be considering forward-mode autodifferentiation instead?
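For reference, here's the forward-mode version I've been sketching: with a single scalar parameter, one jvp with tangent dC = 1 should give dy/dC for every output in one pass. (This assumes torch.func.jvp, called functorch.jvp in older releases, and that my real routine can be wrapped as a pure function of C.)

```python
import torch

x = torch.arange(10)

def routine(c):
    return c * x  # stand-in for the real complicated routine

C = torch.tensor(-5.)
# forward-mode: push the tangent dC = 1 through the routine
y, dy_dC = torch.func.jvp(routine, (C,), (torch.ones(()),))
print(dy_dC)  # tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
```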