Hi, and thanks for your answers!
I have tried to implement the double-back trick for computing jvps, as described in
https://j-towns.github.io/2017/06/12/A-new-trick.html
and as used in PyTorch's original implementation of jvp:
https://pytorch.org/docs/stable/generated/torch.autograd.functional.jvp.html
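For reference, the trick itself fits in a few lines. Here is a minimal, self-contained sketch on a toy linear model without any batch norm (the model, shapes, and names are only illustrative):

    import torch

    # Double-back trick: the vjp g = J^T v is linear in v, so differentiating
    # it with respect to v (with grad_outputs = u) yields the jvp J u.
    model = torch.nn.Linear(4, 3)
    x = torch.randn(8, 4)
    u = [torch.randn_like(p) for p in model.parameters()]  # tangent direction

    out = model(x)
    v = torch.zeros_like(out, requires_grad=True)  # dummy; its value is irrelevant
    g = torch.autograd.grad(out, model.parameters(), v, create_graph=True)
    jvp = torch.autograd.grad(g, v, u)[0]  # same shape as out

A sketch like this runs without issue for a plain model; the failure below seems to appear only once the model contains SyncBatchNorm layers.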
However, SyncBatchNorm still seems to cause problems. I get the following error:
File "/home/mitchell/DivideMix/meta_utils.py", line 261, in jacobian_vector_product
jvp = torch.autograd.grad(g, v, vector)[0]
File "/home/mitchell/miniconda3/envs/deep_learn/lib/python3.9/site-packages/torch/autograd/__init__.py", line 276, in grad
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: derivative for aten::batch_norm_backward_elemt is not implemented
Here is my implementation:
import torch
import torch.nn.functional as F

def jacobian_vector_product(model, inputs, vector, method='double-back-trick'):
    if method == 'forward':
        '''
        jvp using forward-mode AD as demonstrated in:
        https://pytorch.org/tutorials/intermediate/forward_ad_usage.html
        '''
        params = {name: p for name, p in model.named_parameters()}
        tangents = {name: vec_pt for name, vec_pt in zip(params.keys(), vector)}
        with torch.autograd.forward_ad.dual_level():
            # Replace every parameter with a dual tensor carrying its tangent
            # (del_attr / set_attr are small attribute helpers defined elsewhere).
            for name, p in params.items():
                del_attr(model, name.split("."))
                set_attr(model, name.split("."), torch.autograd.forward_ad.make_dual(p, tangents[name]))
            out = model(inputs)
            lsm = F.log_softmax(out, dim=1)
            jvp = torch.autograd.forward_ad.unpack_dual(lsm).tangent
        return jvp
    elif method == 'double-back-trick':
        '''
        jvp using double backward as demonstrated in:
        https://j-towns.github.io/2017/06/12/A-new-trick.html
        https://discuss.pytorch.org/t/forward-mode-ad-vs-double-grad-trick-for-jacobian-vector-product/159037
        '''
        out = model(inputs)
        lsm = F.log_softmax(out, dim=1)
        # Dummy output-shaped variable: the vjp below is linear in v,
        # so its value does not matter, only that it requires grad.
        v = torch.zeros_like(lsm, requires_grad=True)
        # First backward: vjp w.r.t. the parameters, keeping the graph.
        g = torch.autograd.grad(lsm, model.parameters(), v, create_graph=True)
        # Second backward: differentiate the vjp w.r.t. v to obtain the jvp.
        jvp = torch.autograd.grad(g, v, vector)[0]
        return jvp
    else:
        raise NotImplementedError
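For completeness, a call looks roughly like this (the toy model and tensors below are placeholders; in my real setup the batch norm layers are converted to SyncBatchNorm, which is where the error occurs):

    import torch
    import torch.nn as nn

    net = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 3))
    inputs = torch.randn(16, 4)
    vector = [torch.randn_like(p) for p in net.parameters()]  # one tangent per parameter

    jvp = jacobian_vector_product(net, inputs, vector, method='double-back-trick')
    print(jvp.shape)  # matches F.log_softmax(net(inputs), dim=1), i.e. torch.Size([16, 3])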
Am I making a mistake somewhere, or is this another issue with SyncBatchNorm?