Hi everyone,
I am trying to move my code to tensordict (td) and vmap along the batch, following How to apply vmap on a heterogeneous tensor - #5 by soulitzer. Since vmap does not support batched tensordicts, for now I just use batch size = 1.
Everything worked very well until I needed to rewrite the derivative module. I have a td with the distances between atoms and predict the energy with a model, and then I need to calculate the force by differentiating the energy w.r.t. the distance.
Before, I used torch.autograd.grad, which was very straightforward. But within vmap it raises: `element 0 of tensors does not require grad and does not have a grad_fn`. I found that the BatchedTensor of the energy has a grad_fn, but the "vmapped" (underlying) tensor has requires_grad set to False.
Apparently(?) we cannot combine vmap with torch.autograd.grad and should use torch.func.grad instead, according to Use vmap and grad to calculate gradients for one layer independently for each input in batch - #2 by AlphaBetaGamma96 and Simple use case: Compete per sample gradient with autograd - #4 by AlphaBetaGamma96.
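For reference, this is my understanding of the per-sample-gradient pattern from those threads, written as my own toy example (not taken from my real code):

```python
import torch
from torch.func import grad, vmap

def f(x):
    # scalar output per sample, so grad(f) gives df/dx for that sample
    return (x ** 2).sum()

x = torch.randn(8, 3)                 # batch of 8 samples
per_sample_grads = vmap(grad(f))(x)   # shape (8, 3), equal to 2 * x
```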
Since torch.func.vmap can only specify arguments by position, how can I differentiate energy w.r.t. distance when both are entries of a tensordict? Here is my pseudo code:
```python
inputs = td({'distance': tensor})
inputs = model(inputs)  # predict energy
# grad_outputs should match the shape of energy (ones, not zeros, to get d(energy)/d(distance))
dedx = torch.autograd.grad(inputs['energy'], inputs['distance'],
                           grad_outputs=torch.ones_like(inputs['energy']))[0]
# how to replace torch.autograd.grad with torch.func.grad under vmap?
```
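What I imagine the torch.func version could look like is the sketch below, where I rebuild the td inside the wrapped function so that distance becomes a plain positional argument. The td wrapping/unwrapping, the .sum() to get a scalar energy, and the names `model` / `batched_distance` are guesses and placeholders, not working code:

```python
import torch
from torch.func import grad, vmap
from tensordict import TensorDict

def energy_of_distance(distance):
    # rebuild the td inside the function so that `distance` is a positional argument
    td_in = TensorDict({'distance': distance}, batch_size=[])
    td_out = model(td_in)              # predict energy
    return td_out['energy'].sum()      # torch.func.grad needs a scalar output

dedx_fn = grad(energy_of_distance)         # d(energy)/d(distance)
force = -vmap(dedx_fn)(batched_distance)   # one force per sample in the batch
```

But I am not sure this is the right way to handle the tensordict, which is why I am asking.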
Thanks for your help!
Also posted in the tensordict discussions: Combine vmap, func.grad with tensordict · pytorch/tensordict · Discussion #1167 · GitHub