Combine vmap, func.grad with tensordict

Hi everyone,

I am trying to move my code to tensordict (td) and to vmap along the batch dimension, following How to apply vmap on a heterogeneous tensor - #5 by soulitzer. Since vmap does not support batched tensordicts, for now I just use a batch size of 1.

Nevertheless, everything worked very well until I needed to rewrite the derivative module. I have a td containing the distances between atoms and predict the energy with a model, and then I need to compute the force by differentiating the energy w.r.t. the distance.

Before, I used torch.autograd.grad, which was very straightforward. But inside vmap it raises: "element 0 of tensors does not require grad and does not have a grad_fn". I found that the BatchedTensor holding the energy has a grad_fn, but the underlying real tensor (after unwrapping the "vmapped" tensor) has requires_grad set to False.

Apparently(?) we cannot combine vmap with torch.autograd.grad, but torch.func.grad does work, according to Use vmap and grad to calculate gradients for one layer independently for each input in batch - #2 by AlphaBetaGamma96, and Simple use case: Compete per sample gradient with autograd - #4 by AlphaBetaGamma96.
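For reference, the composition those threads describe looks roughly like this toy example (f here is just a stand-in scalar function, not my model):

import torch
from torch.func import grad, vmap

def f(x):
    return (x ** 2).sum()  # grad requires a scalar output

x = torch.randn(8, 3)
per_sample_grads = vmap(grad(f))(x)  # shape (8, 3); equals 2 * x here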

Since torch.func.grad can only select what to differentiate by positional argument (via argnums), how can I differentiate the energy w.r.t. the 'distance' entry inside the td? Here is my pseudocode:

import torch
from tensordict import TensorDict

inputs = TensorDict({'distance': tensor}, batch_size=[])  # tensor: interatomic distances
inputs['distance'].requires_grad_(True)
inputs = model(inputs)  # predict energy
dedx = torch.autograd.grad(inputs['energy'].sum(), inputs['distance'])[0]
# how to replace torch.autograd.grad with torch.func.grad here?
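Is wrapping the model so that the distance tensor becomes a positional argument the right direction? Something like the sketch below, where model and batched_distance stand in for my real module and data (and I am not sure whether building a TensorDict inside vmap is even supported):

import torch
from torch.func import grad, vmap
from tensordict import TensorDict

def energy_fn(distance):
    # rebuild the td inside, so `distance` is addressable via argnums
    td_in = TensorDict({'distance': distance}, batch_size=[])
    return model(td_in)['energy'].sum()  # scalar output for grad

# differentiate w.r.t. argument 0, then vmap over the batch dimension
dedx = vmap(grad(energy_fn, argnums=0))(batched_distance)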

Thanks for your help!

Also posted in the tensordict discussions: Combine vmap, func.grad with tensordict · pytorch/tensordict · Discussion #1167 · GitHub