Batch-differentiating with respect to a single parameter

I have a batch of inputs

``````
import torch
x = torch.arange(10)
``````

and a complicated routine that depends on some parameter C

``````
C = torch.tensor(-5., requires_grad=True)
``````

I use that routine to compute some outputs `y`. Rather than dump the whole complicated routine, I’ll make up something simple that gives the same issues.

``````
y = C * x
``````

Now, I want to compute the derivative of y with respect to the parameter, dy/dC. In this case, of course, I know it’s x, but my ‘real’ routine is more involved and its derivative is nontrivial to compute / program.

I know how to do it elementwise:

``````
torch.tensor(tuple(torch.autograd.grad(yi, C, retain_graph=True)[0] for yi in y))
# tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]) # the correct answer!
``````

But in my real problem this is pretty slow. Can I eliminate the Python-level for loop?

I first tried

``````
torch.autograd.grad(y, C)
# gives
# RuntimeError: grad can be implicitly created only for scalar outputs
``````

and then

``````
torch.autograd.grad(y, C, grad_outputs=torch.ones(10), allow_unused=True)
# (tensor(45.),)
``````

which is close: that’s `x.sum()`. I think it essentially computed `torch.dot(torch.ones(10), dy/dC)`, where `dy/dC == x`.
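That guess can be checked directly: with `grad_outputs=v`, `torch.autograd.grad` returns the vector-Jacobian product, i.e. `torch.dot(v, dy/dC)`. A minimal sketch on the same toy `y = C * x`; the weight vector `v` is an arbitrary choice made up for illustration:

```python
import torch

x = torch.arange(10)
C = torch.tensor(-5., requires_grad=True)
y = C * x  # toy stand-in for the real routine; dy/dC == x

# An arbitrary (hypothetical) weighting of the ten outputs.
v = torch.linspace(0.0, 1.0, 10)

# Backward mode with grad_outputs=v computes sum_i v[i] * dy[i]/dC.
(vjp,) = torch.autograd.grad(y, C, grad_outputs=v)

# Compare with the explicit dot product against the known derivative x.
expected = torch.dot(v, x.to(v.dtype))
print(vjp, expected)
```

With `v = torch.ones(10)` this reduces to the `x.sum()` result above.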

Well, that’s almost what I want! Since `torch.dot(torch.ones(10), ...) == torch.matmul(torch.ones(10), ...)`, can I replace `torch.ones` with `torch.eye` and get dy/dC as desired?

``````
torch.autograd.grad(y, C, grad_outputs=torch.eye(10))
# RuntimeError: Mismatch in shape: grad_output[0] has a shape of torch.Size([10, 10]) and output[0] has a shape of torch.Size([10]).
``````

No, that doesn’t work. Perhaps if I expand C?

``````
torch.autograd.grad(y, C.expand(10), grad_outputs=torch.ones(10))
# RuntimeError: One of the differentiated Tensors appears to not have been used in the graph.
``````

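Hm, I suppose the computational graph can’t “see through” the fact that every entry of `C.expand` is `C`. That makes sense: `C.expand(10)` creates a new tensor that was never used to compute `y`, so it is not part of the graph.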
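The flip side is that if the expanded copy is created *before* the routine and actually used to compute `y`, the graph does contain it, and one backward pass returns every per-element derivative. A minimal sketch, under the strong assumption that each `y[i]` depends on `C` only through its own copy `C_exp[i]` (true for this toy `C * x`, but not for a routine that mixes elements); `C_exp` is a hypothetical name:

```python
import torch

x = torch.arange(10)
C = torch.tensor(-5., requires_grad=True)

# Give every output its own (view) copy of C and use that copy in the computation.
C_exp = C.expand(x.shape)
y = C_exp * x  # toy routine; y[i] depends only on C_exp[i]

# One backward pass: d(sum(y))/dC_exp[i] == dy[i]/dC under the assumption above.
(dy_dC,) = torch.autograd.grad(y, C_exp, grad_outputs=torch.ones_like(y))
print(dy_dC)  # tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
```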

Maybe I need to differentiate elementwise and use vmap?

``````
import functorch
def ddC(y):
    return d[0]

functorch.vmap(ddC)(y)
# ValueError: vmap(vjp, in_dims=0, ...)(<inputs>): Got in_dim=0 for an input
# but the input is of type <class 'NoneType'>. We cannot vmap over non-Tensor
# arguments, please use None as the respective in_dim
``````

I don’t understand this error, and I don’t understand where `NoneType` could be coming from.

Any advice? How can I get the same result as the element-wise differentiation without a slow Python for loop? Or is that what I should stick to?

Should I be considering forward-mode autodifferentiation instead?

Hi Evan!

Backward-mode autograd (e.g., `torch.autograd.grad()`) will let you compute
the derivatives of a single scalar (e.g., a loss) with respect to a batch of
variables (i.e., compute the gradient) in a single pass, but it won’t compute
the derivatives of a batch of results with respect to a single variable in a
single pass.
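To make the contrast concrete, here is the case backward mode does handle in one pass: one scalar and many variables. A minimal sketch; the weights `w` and the loss are made up for illustration:

```python
import torch

x = torch.arange(10, dtype=torch.float32)
w = torch.zeros(10, requires_grad=True)  # a batch of ten variables

loss = (w * x).sum()  # a single scalar result

# One backward pass yields d(loss)/dw[i] for all i at once.
(dloss_dw,) = torch.autograd.grad(loss, w)
print(dloss_dw)  # tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
```

The question in this thread is the transpose of this situation: ten results and one variable.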

As for forward-mode autograd: yes, you should consider it (with the
proviso that it is still in beta / experimental).

I have not used forward-mode autograd for anything real, so I can’t
speak to its stability (nor performance), but here is an illustration applied
to your example:

``````
>>> import torch
>>> print (torch.__version__)
1.12.0
>>>
>>> x = torch.arange (10)
>>> C = torch.tensor (-5., requires_grad=True)
>>> y = C * x
>>>
>>> # compute "batch-derivative" with loop over backward-mode autograd
>>> resultA = torch.tensor (tuple (torch.autograd.grad (yi, C, retain_graph=True)[0] for yi in y))
>>> resultA
tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
>>>
>>>
>>> # compute "batch-derivative" with one pass of forward-mode autograd
>>> Ct = torch.tensor (1.0)
...     y_dual = C_dual * x
...
>>> resultB
tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
>>> torch.equal (resultB, resultA)
True
``````
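For reference, here is a self-contained script version of a single forward-mode pass on the toy problem. It is a minimal sketch that assumes the `torch.autograd.forward_ad` API (`dual_level`, `make_dual`, `unpack_dual`) found in recent PyTorch releases; the alias `fwAD` is just a local name:

```python
import torch
import torch.autograd.forward_ad as fwAD

x = torch.arange(10)
C = torch.tensor(-5.)  # forward mode does not need requires_grad

# Tangent (direction) for C: 1.0, so the output tangent is dy/dC itself.
Ct = torch.tensor(1.0)

with fwAD.dual_level():
    # Attach the tangent Ct to the primal value C.
    C_dual = fwAD.make_dual(C, Ct)
    # Run the routine once on the dual tensor.
    y_dual = C_dual * x
    # The tangent of y is the Jacobian-vector product (dy/dC) * Ct.
    resultB = fwAD.unpack_dual(y_dual).tangent

print(resultB)  # tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
```

Because `C` is a single scalar, one tangent covers the whole Jacobian, which is why a single forward pass suffices here.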

Best.

K. Frank

Hi K. Frank,

Thanks for your response. I will try forward-mode autodiff; we’ll see what happens. If I have to write code / do a manual loop, so be it. What’s peculiar to me is that the computational graph for each entry of y is the same graph; shouldn’t it be reusable?

Can you explain, with an example, what the `is_grads_batched` option of `torch.autograd.grad` does? The documentation is pretty opaque about the actual use case. Also, can you explain why my `vmap(ddC)` approach failed? I thought that was the way to avoid the manual loop over y.

Hi Evan!

To be clear, you are reusing the computation graph.

You run `y = C * x` once (the forward pass), which creates the computation
graph just once. Then when you run
`torch.autograd.grad(yi, C, retain_graph=True)` in your `for yi in y`
loop, you use the computation graph and retain it so that the same
computation graph can be used again in subsequent iterations of the
loop (without recreating it). (You are performing multiple backward passes,
but through the same, reused computation graph.)
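The reuse described above can be seen in a small sketch: the graph is built once by the forward pass, and `retain_graph=True` is what lets later backward passes traverse it again. (A shorter `x` is used only to keep the output small; `second_pass_failed` is a hypothetical flag.)

```python
import torch

x = torch.arange(3)
C = torch.tensor(-5., requires_grad=True)
y = C * x  # forward pass: the graph is built once, here

# Each call is a separate backward pass through that one graph;
# retain_graph=True keeps its saved tensors alive for the next call.
grads = [torch.autograd.grad(yi, C, retain_graph=True)[0] for yi in y]
print(torch.stack(grads))  # tensor([0., 1., 2.])

# Without retain_graph, a backward pass frees the graph's saved tensors,
# so a further pass through the same graph raises a RuntimeError.
torch.autograd.grad(y[0], C)
second_pass_failed = False
try:
    torch.autograd.grad(y[1], C)
except RuntimeError as err:
    second_pass_failed = True
    print("second pass failed:", err)
```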

I’ve never used `is_grads_batched` nor `vmap` (although the documentation
suggests that `is_grads_batched` uses `vmap` under the hood).

Best.

K. Frank

Hi Evan!

A quick follow-up:

When using `is_grads_batched` you have to pass `torch.autograd.grad()`
a batch of `grad_outputs`, packaged as a tensor. In your use case
`torch.eye()` suffices:

``````
>>> import torch
>>> print (torch.__version__)
1.12.0
>>>
>>> x = torch.arange (10)
>>> C = torch.tensor (-5., requires_grad=True)
>>> y = C * x
>>>
>>> # compute "batch-derivative" with loop over backward-mode autograd
>>> resultA = torch.tensor (tuple (torch.autograd.grad (yi, C, retain_graph=True)[0] for yi in y))
>>> resultA
tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
>>>
>>>
>>> resultB
tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
>>> torch.equal (resultB, resultA)
True
``````
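A self-contained script version of the batched call, as a minimal sketch: the batch of `grad_outputs` is `torch.eye(10)`, whose row `i` selects `y[i]`, and the name `basis` is just a local choice:

```python
import torch

x = torch.arange(10)
C = torch.tensor(-5., requires_grad=True)
y = C * x  # toy routine; dy/dC == x

# Row i of the identity is the grad_output that selects y[i];
# is_grads_batched=True treats the first dimension as a batch of grad_outputs.
basis = torch.eye(y.numel())

(resultB,) = torch.autograd.grad(y, C, grad_outputs=basis, is_grads_batched=True)
print(resultB)  # tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
```

The result has the batch dimension (10) followed by the shape of `C` (a scalar), so it comes out as a length-10 vector.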

It’s interesting to note that I ran this example on PyTorch 1.12 without
`functorch` (or, to my knowledge, `vmap()`) installed, so I have no idea
how the batch of gradient computations is being vectorized. (Maybe
there’s a secret copy of `vmap()` hiding somewhere, or maybe this falls
back to a simple loop with no speed-up.)

Best.

K. Frank

If I have the following operations:

``````
x = th.randn(10, 1)
``````

I was looking at the torch.vmap tutorial page (PyTorch Tutorials 1.12.1+cu102 documentation) and at what was posted above, but I do not fully understand the role of `basis_vectors`, nor how forward-mode AD code is structured with tangents and primals.
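One way to see both roles is a small sketch on a made-up function `f` (not the operations above, which are not shown in full). It assumes PyTorch 2.x, where `torch.func.jvp` and `torch.func.vmap` exist. In forward mode the *primal* is the point where `f` is evaluated and the *tangent* is a direction in input space; `jvp` returns the Jacobian times that tangent, so feeding it each basis vector (a row of `torch.eye`) yields one Jacobian column. In reverse mode, a basis vector of the *output* space passed as `grad_outputs` yields one Jacobian row; as far as I can tell, that is the role `basis_vectors = torch.eye(N)` plays in the tutorial.

```python
import torch
from torch.func import jvp, vmap

# Hypothetical routine for illustration: maps a (3,)-vector to a (2,)-vector.
def f(p):
    return torch.stack([p[0] * p[1], torch.sin(p[2])])

primal = torch.tensor([1.0, 2.0, 3.0])  # the point where derivatives are taken
basis_vectors = torch.eye(3)            # rows e_0, e_1, e_2 of the input space

# Forward mode: jvp returns (f(primal), J @ tangent).
def jvp_along(tangent):
    _, jvp_out = jvp(f, (primal,), (tangent,))
    return jvp_out

# Each basis tangent gives one Jacobian column; vmap batches the passes.
columns = vmap(jvp_along)(basis_vectors)  # shape (3, 2): row k is J[:, k]
jacobian_fwd = columns.T                  # shape (2, 3)

# Reverse mode: each basis vector of the output space, passed as
# grad_outputs, selects one row of the Jacobian.
p = primal.clone().requires_grad_(True)
out = f(p)
rows = [torch.autograd.grad(out, p, grad_outputs=e, retain_graph=True)[0]
        for e in torch.eye(2)]
jacobian_rev = torch.stack(rows)          # shape (2, 3)

print(jacobian_fwd)
print(jacobian_rev)
```

Forward mode needs one pass per input direction and reverse mode one pass per output direction, which is why forward mode fits the one-parameter, many-outputs case in this thread.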