Hi All,
I’ve been experimenting with creating a custom optimizer class. This optimizer requires extra information in order to work, which means I have to calculate more than one loss. I can say a bit more about this optimizer if needed, but in short, it’s a modified version of KFAC.
When using KFAC, you precondition your gradients using information gathered from `register_forward_pre_hook` and `register_full_backward_hook`. However, the `grad_output` I require from `register_full_backward_hook` isn’t the same as what comes from my loss function, so I get around this issue by using two loss functions and caching the needed values. It goes something like this:
- Add `register_forward_pre_hook` and `register_full_backward_hook` to all layers of my model.
- Call my loss function and cache the correct gradients I need for my parameters. The backward hooks, however, see the wrong `grad_output` values. (The forward hook is fine as it’s independent of the loss function at hand.)
- Call an auxiliary (or secondary) loss function, which gives the right `grad_output` to the backward hooks but the wrong gradients for my parameters.
- Grab the correct gradients from my cache and overwrite the incorrect gradients (from the auxiliary loss function) with the correct ones from my loss function. This effectively solves the problem of needing the backward hooks to see a different `grad_output` value than the one coming from my loss function.
- (To summarise, in case that was confusing!) I want the gradients from my loss function but the `grad_output` values (which come from the `full_backward_hook`) from a different loss function. A minimal sketch of this procedure is given right after this list.
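To make the ordering concrete, here’s a minimal sketch of the procedure (heavily simplified; the model, both losses, and the hook bookkeeping below are just placeholders, not my actual KFAC code):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 1))

activations, grad_outputs = {}, {}

def forward_pre_hook(module, inputs):
    # cache the layer input (independent of which loss is used)
    activations[module] = inputs[0].detach()

def full_backward_hook(module, grad_input, grad_output):
    # cache grad_output; I need this to come from the auxiliary loss
    grad_outputs[module] = grad_output[0].detach()

for layer in model.modules():
    if isinstance(layer, nn.Linear):
        layer.register_forward_pre_hook(forward_pre_hook)
        layer.register_full_backward_hook(full_backward_hook)

x = torch.randn(16, 4)

# 1) main loss: backprop it and cache the parameter gradients I actually want
main_loss = model(x).pow(2).mean()          # placeholder for my real loss
model.zero_grad()
main_loss.backward()
cached_grads = {p: p.grad.detach().clone() for p in model.parameters()}

# 2) auxiliary loss: backprop it only so the backward hooks see the
#    grad_output values I need; its parameter gradients are thrown away
aux_loss = model(x).mean()                  # placeholder for my auxiliary loss
model.zero_grad()
aux_loss.backward()

# 3) overwrite the (wrong) gradients from the auxiliary loss with the cached
#    gradients from the main loss before the optimizer step
for p in model.parameters():
    p.grad = cached_grads[p]
```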
I’ve managed to do this with the caching procedure described above. However, it only works for one epoch, which is odd: PyTorch’s traceback says it failed to set up the hooks, even though it clearly managed to do so on the previous epoch. The full traceback is as follows.
```
Traceback (most recent call last):
  File "test_hooks_double_loss.py", line 120, in <module>
    X, acceptance = sampler(burn_in)
  File "~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1065, in _call_impl
    return forward_call(*input, **kwargs)
  File "~/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "~/Code/Samplers.py", line 94, in forward
    self.step()
  File "~/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "~/Code/Samplers.py", line 72, in step
    log_pdf_chains = self._log_pdf(self.chains).detach_()
  File "~/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "~/Code/Samplers.py", line 48, in _log_pdf
    return self.network(x)[1].mul(2).detach_()
  File "~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1065, in _call_impl
    return forward_call(*input, **kwargs)
  File "~/Code/Models.py", line 59, in forward
    log_envs = self.log_envelope(x0)
  File "~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1081, in _call_impl
    input = bw_hook.setup_input_hook(input)
  File "~/.local/lib/python3.8/site-packages/torch/utils/hooks.py", line 158, in setup_input_hook
    res, input_idx = self._apply_on_tensors(fn, args)
  File "~/.local/lib/python3.8/site-packages/torch/utils/hooks.py", line 142, in _apply_on_tensors
    raise RuntimeError("Error while setting up backward hooks. Please open "
RuntimeError: Error while setting up backward hooks. Please open an issue with a code sample to reproduce this.
```
Although the traceback asks for a reproducible example, that seems quite difficult given that I’ve got close to 1000 lines of code spread over numerous files. Would sharing parts of `main.py` with my caching be an acceptable start to debugging this?
Any help with this would be greatly appreciated!
Thank you!
Edit: I assume there’s something wrong with `self.log_envelope(x0)`? The source code for that layer is given below.
```python
import torch
import torch.nn as nn

class LogEnvelope(nn.Module):

    def __init__(self, num_particles, num_dets, bias=False) -> None:
        super(LogEnvelope, self).__init__()
        self.num_particles = num_particles
        self.num_dets = num_dets
        # one weight per (determinant, particle) pair
        self.weight = nn.Parameter(torch.empty(self.num_dets, self.num_particles))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        torch.nn.init.uniform_(self.weight, 1., 2.)

    def forward(self, x0):
        # scale x0 per determinant/particle, square, and sum over the last dim
        return torch.einsum("bin,di->bdin", x0, self.weight).pow(2).sum(dim=-1)
```
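From the einsum string, the layer expects `x0` of shape `(batch, num_particles, n)` and returns `(batch, num_dets, num_particles)`; a toy call (illustrative only, not my real data) would be:

```python
layer = LogEnvelope(num_particles=3, num_dets=2)
x0 = torch.randn(5, 3, 4)   # (batch=5, num_particles=3, n=4)
out = layer(x0)             # out.shape == torch.Size([5, 2, 3])
```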