Hello
I have an application that requires computing the derivative of the model's logits (more precisely, of their log-softmax) with respect to the model's parameters, multiplied by a constant vector, i.e. a Jacobian-vector product (JVP).
In theory this can be computed exactly in a single forward pass, as demonstrated in the following PyTorch tutorial: Forward-mode Automatic Differentiation (Beta) — PyTorch Tutorials 1.13.1+cu117 documentation.
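The core primitive from that tutorial looks like this (a minimal sketch on a toy tensor, not my actual model):

import torch
import torch.autograd.forward_ad as fwAD

primal = torch.randn(3)
tangent = torch.randn(3)  # plays the role of the constant vector v

with fwAD.dual_level():
    dual = fwAD.make_dual(primal, tangent)
    out = dual.exp()
    # The tangent attached to the output is J @ v, obtained in the same forward pass
    jvp = fwAD.unpack_dual(out).tangent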
Following this tutorial, I managed to write a version of this code for my use case:
import torch
import torch.nn.functional as F
import torch.autograd.forward_ad as fwAD

def jacobian_vector_product(model, inputs, vector):
    # vector holds one tangent tensor per parameter, in named_parameters() order
    params = {name: p for name, p in model.named_parameters()}
    tangents = {name: vec_pt for name, vec_pt in zip(params.keys(), vector)}
    with fwAD.dual_level():
        # Replace every parameter with a dual tensor that carries its tangent
        for name, p in params.items():
            del_attr(model, name.split("."))
            set_attr(model, name.split("."), fwAD.make_dual(p, tangents[name]))
        out = model(inputs)
        lsm = F.log_softmax(out, dim=1)
        # The tangent attached to the log-softmax output is exactly J @ vector
        jvp = fwAD.unpack_dual(lsm).tangent
    return jvp
(Note that the tutorial's setattr and delattr were replaced here by nested versions, since named_parameters() returns dotted names such as layer1.0.weight; a sketch of these helpers follows.)
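The nested helpers look roughly like this (a minimal sketch; they simply walk the dotted attribute path recursively):

def del_attr(obj, names):
    # Descend to the parent module, then delete the leaf attribute
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        del_attr(getattr(obj, names[0]), names[1:])

def set_attr(obj, names, val):
    # Descend to the parent module, then set the leaf attribute
    if len(names) == 1:
        setattr(obj, names[0], val)
    else:
        set_attr(getattr(obj, names[0]), names[1:], val)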
Indeed, the above code works perfectly when I don’t use any kind of data parallelism.
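For example, on a single GPU a call along these lines returns the expected JVP (the model and shapes here are just placeholders):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)).cuda()
inputs = torch.randn(32, 1, 28, 28, device="cuda")
# One tangent tensor per parameter, matching named_parameters() order
vector = [torch.randn_like(p) for p in model.parameters()]

jvp = jacobian_vector_product(model, inputs, vector)  # shape (32, 10)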
However, when I wrap the model with nn.DataParallel(), the forward pass in this function yields an error indicating that the model and the inputs are not on the same device.
After some quick debugging, I found that the batch is indeed split into multiple chunks and that the inputs are distributed correctly; however, all the forward passes occur on cuda:0 instead of being distributed to match the inputs.
I think the issue has something to do with the attribute-editing loop: when I commented it out, the forward pass worked properly (although, of course, the jvp was None).
I have also tried wrapping the model with DataParallel after editing the attributes, but I get the same error.
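For completeness, the failing setup is roughly:

import torch.nn as nn

dp_model = nn.DataParallel(model)  # model already lives on cuda:0
jvp = jacobian_vector_product(dp_model, inputs, vector)
# -> device-mismatch error inside the wrapped forward: the inputs are
#    scattered across cuda:0, cuda:1, ..., but every replica seems to run
#    with the dual parameters that live on cuda:0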
I would appreciate any assistance on this,
Thanks!