RuntimeError: Error while setting up backward hooks. Please open an issue with a code sample to reproduce this

Hi All,

I’ve been experimenting with creating a custom optimizer class. This optimizer needs extra information in order to work, which means I have to compute more than one loss. I can say a little more about the optimizer, but in short it’s a modified version of KFAC.

With KFAC, you precondition your gradients using information gathered from register_forward_pre_hook and register_full_backward_hook. However, the grad_output I need from register_full_backward_hook isn’t the same as the one produced by my loss function, so I get around this issue by using 2 loss functions and caching the needed values. It goes something like this,

  • add register_forward_pre_hook and register_full_backward_hook to all layers of my model.

  • call my loss function, and cache the correct gradients I need for my model. At this point, though, the backward hooks see the wrong grad_output values. (The forward hook is fine, as it’s independent of the loss function at hand.)

  • call an auxiliary (or secondary) loss function, which produces the right grad_output for the backward hooks but the wrong gradients for my parameters.

  • grab the correct gradients from my cache and overwrite the incorrect gradients (from the auxiliary loss function) with them. This effectively solves the problem of needing the backward hooks to see a different grad_output value than the one from my loss function.

  • (To summarize, in case that was confusing!) I want the parameter gradients from my loss function, but the grad_output values (which come from the full backward hooks) from a different loss function. A minimal sketch of this procedure is shown after this list.
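
For concreteness, here is a minimal sketch of the caching procedure. The model, the hook, the placeholder losses, and the name grad_cache are purely illustrative, not my actual code:

import torch
import torch.nn as nn

model = nn.Linear(4, 1)

def _backward_hook(module, grad_input, grad_output) -> None:
  #sees the grad_output of whichever loss was backpropagated last
  print("grad_output[0]: ", grad_output[0])

model.register_full_backward_hook(_backward_hook)

x = torch.randn(8, 4)

#1) main loss: cache the parameter gradients I actually want
model(x).pow(2).mean().backward()
grad_cache = {name: p.grad.clone() for name, p in model.named_parameters()}
model.zero_grad()

#2) auxiliary loss: fires the hooks with the grad_output I want,
#   but leaves the wrong gradients on the parameters
model(x).sum().backward()

#3) overwrite the auxiliary gradients with the cached main-loss gradients
for name, p in model.named_parameters():
  p.grad = grad_cache[name]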

I’ve managed to do this with the caching procedure described above. However, it only works for one epoch, which is odd: PyTorch tells me in the traceback that it failed to set up the hooks, but it clearly managed to do so in the previous epoch. The full traceback is as follows.

Traceback (most recent call last):
  File "test_hooks_double_loss.py", line 120, in <module>
    X, acceptance = sampler(burn_in)
  File "~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1065, in _call_impl
    return forward_call(*input, **kwargs)
  File "~/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "~/Code/Samplers.py", line 94, in forward
    self.step()
  File "~/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "~/Code/Samplers.py", line 72, in step
    log_pdf_chains = self._log_pdf(self.chains).detach_()
  File "~/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "~/Code/Samplers.py", line 48, in _log_pdf
    return self.network(x)[1].mul(2).detach_()
  File "~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1065, in _call_impl
    return forward_call(*input, **kwargs)
  File "~/Code/Models.py", line 59, in forward
    log_envs = self.log_envelope(x0)
  File "~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1081, in _call_impl
    input = bw_hook.setup_input_hook(input)
  File "~/.local/lib/python3.8/site-packages/torch/utils/hooks.py", line 158, in setup_input_hook
    res, input_idx = self._apply_on_tensors(fn, args)
  File "~/.local/lib/python3.8/site-packages/torch/utils/hooks.py", line 142, in _apply_on_tensors
    raise RuntimeError("Error while setting up backward hooks. Please open "
RuntimeError: Error while setting up backward hooks. Please open an issue with a code sample to reproduce this.

Although the traceback asks for a reproducible example, that seems quite difficult given I’ve got nearly 1000 lines of code spread over numerous files. Would sharing parts of main.py with my caching be an acceptable start to debugging this?

Any help with this would be greatly appreciated! :slight_smile:

Thank you!

Edit: I assume there’s something wrong with self.log_envelope(x0)? The source code for that layer is given below,

class LogEnvelope(nn.Module):

  def __init__(self, num_particles, num_dets, bias=False) -> None:
    super(LogEnvelope, self).__init__()

    self.num_particles = num_particles
    self.num_dets = num_dets

    #one weight per (determinant, particle) pair; no bias term is used
    self.weight = nn.Parameter(torch.empty(self.num_dets, self.num_particles))
    self.reset_parameters()

  def reset_parameters(self) -> None:
    torch.nn.init.uniform_(self.weight, 1., 2.)

  def forward(self, x0):
    #scale x0 per determinant and particle, square, and sum over the last dimension
    return torch.einsum("bin,di->bdin", x0, self.weight).pow(2).sum(dim=-1)
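
For reference, a quick shape check of the layer (assuming the class above with the usual torch imports; the sizes are made up, and x0 is assumed to be [batch, particles, dimensions]):

env = LogEnvelope(num_particles=3, num_dets=2)
x0 = torch.randn(4, 3, 5)  #[B, num_particles, N]
print(env(x0).shape)       #torch.Size([4, 2, 3])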

Thanks for reporting this! Could you create an issue on GitHub with your code snippet (even if it’s long) to get the visibility of the code owners, please?

Hi @ptrblck! I can do. I’ll try to reduce it into a slimmer file (and hopefully just one file, although I can’t promise anything!) and open a new issue on GitHub! Thanks!

Hi @ptrblck,

During the process of creating a standalone script, I found the error, and it truly is bizarre. The problem emerges from the file Samplers.py. It’s effectively a module that uses a network to generate input data in accordance with a given probability distribution, using the Metropolis-Hastings algorithm and a random walk.

In short, and to add some context, it generates input data for my feed-forward neural network (which is an R^N → R^1 function); the input data is distributed according to the square of the network’s output.

input_data, acceptance = Sampler(10) #10 steps of Metropolis-Hastings (input_data is shape [B,N])
                                     #input_data is distributed according to net(input_data).pow(2)

output = net(input_data) #my R^N -> R^1 function
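
To make that concrete, here’s a minimal sketch of one random-walk Metropolis-Hastings step, assuming log_pdf(x) returns the log of the (unnormalized) target density per chain (in my case 2 * log|net(x)|); the names are illustrative:

import torch

@torch.no_grad()
def metropolis_step(chains, log_pdf, step_size=0.1):
  #propose a random-walk move for every chain in the batch
  proposal = chains + step_size * torch.randn_like(chains)
  #accept with probability min(1, p(proposal) / p(chains))
  log_alpha = log_pdf(proposal) - log_pdf(chains)
  accept = torch.rand(chains.shape[0]).log() < log_alpha
  chains[accept] = proposal[accept]
  return chains, accept.float().mean()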

The line of code that caused the crash is shown below,

  @torch.no_grad()
  def forward(self, burn_in: int) -> Tuple[Tensor, Tensor]:
    self.acceptance.zero_()
    
    for i in range(burn_in+1): #check +1 here
      self.step()
    #return self.chains.detach_().requires_grad_(), self.acceptance.detach_() #crashes
    return self.chains.detach_(), self.acceptance.detach_() #doesn't crash

I called requires_grad_() on the data returned from the Sampler because I need a gradient with respect to it in my loss function. I also enable requires_grad_() within my loss function to make sure it’s enabled, so I can remove requires_grad_() from the Sampler and still calculate my loss. But it seems weird to me that this usage causes such a strange error.

Thank you!

Thanks for the update!
Do you require gradients in the samples themselves or are you also training the network used in the sampler?
Also, your current forward method is using the no_grad context manager. I assume the forward is used in the sampler network?
If so, does the code also raise the error, if you call detach_().requires_grad_() outside of the forward on the returned samples?

I’m training the network within the sampler. (I do require gradients for the samples, but it’s a 2nd-order derivative, and I’m following this neat trick here, which reshapes the inputs and ensures they have requires_grad set.)

Yes, it is used. This method creates my input data which, to give it a more mathematical context rather than a purely computational one, is a set of positions within some N-dimensional vector space. These positions are initially randomly distributed (via torch.randn), then, via the Metropolis-Hastings algorithm, become distributed according to my probability distribution. I don’t require any gradient computations here, only function evaluations, so I decided to run the forward method under a no_grad context manager in order to 1) get a speed-up and 2) not backprop through the sampling algorithm when I call my loss function!
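
To sketch the intent (with placeholder names, and the actual objective simplified away), something like:

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 1))

with torch.no_grad():
  x = torch.randn(4, 2)  #stand-in for the Metropolis-Hastings samples

x.requires_grad_()  #re-enable gradients for the loss
output = net(x).sum()
#create_graph=True so this first derivative can itself be differentiated
(dx,) = torch.autograd.grad(output, x, create_graph=True)
loss = dx.pow(2).sum()  #placeholder 2nd-order-capable objective
loss.backward()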

Does that make things clearer as to why I implemented the forward method that way? (If there’s a mistake, please let me know!)

Good idea. I’ll have a look and follow up with this!

So I moved the detach_().requires_grad_() outside the forward method and it crashes as before. I’ve tried a few different ways of doing this, and I think I have an idea as to what’s causing it.

@torch.no_grad()
def forward(self, burn_in: int) -> Tuple[Tensor, Tensor]:
  self.acceptance.zero_()
    
  for i in range(burn_in+1): #check +1 here
    self.step()
  #return self.chains.detach_().requires_grad_(), self.acceptance.detach_() #doesn't work
  return self.chains.detach_(), self.acceptance.detach_() #works
  #return self.chains.detach().requires_grad_(), self.acceptance.detach_() #also works

So it seems the issue arises from calling one in-place op on the result of another, i.e. chaining the in-place requires_grad_() onto the in-place detach_(), given that the out-of-place detach().requires_grad_() works?

Hi @ptrblck ,

I managed to make a reproducible example in less than 25 lines :smiley:

import torch
import torch.nn as nn

#The same as my Metropolis-Hastings algorithm (albeit shorter)
x = torch.randn(4, 2, requires_grad=True)
fc = nn.Linear(2, 1)
x.detach_().requires_grad_() #in-place detach_() followed by in-place requires_grad_()

def _save_output(module, grad_input, grad_output) -> None:
  print("grad_output[0]: ",grad_output[0])

fc.register_full_backward_hook(_save_output)

with torch.no_grad():
  y = fc(x)

z = x.pow(2).sum(dim=-1)

#the loss is a product of 2 factors, summed over the batch
#first factor, y, carries no gradient (it purely scales)
#second factor, z, carries the gradient (in short, y scales z)
loss = torch.sum(y*z)

loss.backward()

When running this example, it returns,

Traceback (most recent call last):
  File "backward_hook_setup_error_standalone.py", line 15, in <module>
    y = fc(x)
  File "~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1081, in _call_impl
    input = bw_hook.setup_input_hook(input)
  File "~/.local/lib/python3.8/site-packages/torch/utils/hooks.py", line 158, in setup_input_hook
    res, input_idx = self._apply_on_tensors(fn, args)
  File "~/.local/lib/python3.8/site-packages/torch/utils/hooks.py", line 142, in _apply_on_tensors
    raise RuntimeError("Error while setting up backward hooks. Please open "
RuntimeError: Error while setting up backward hooks. Please open an issue with a code sample to reproduce this.

Shall I open an issue with this?

Great! Yes, please create an issue using this code snippet! :slight_smile:

For posterity - this bug has been fixed in torch==1.10

So you want:

import torch
from packaging import version
[...]
    def _register_backward_hook(self, module):
        #full backward hooks only work reliably from 1.10 onwards, where this bug was fixed
        if version.parse(torch.__version__) >= version.parse("1.10"):
            module.register_full_backward_hook(self.backward_hook)
        else:
            module.register_backward_hook(self.backward_hook)
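
A note on the fallback (my understanding, not confirmed in the thread above): on versions before 1.10, register_full_backward_hook can hit the setup error shown above when a module’s input has been modified in-place, so falling back to the legacy register_backward_hook avoids the crash, with the caveat that the legacy hook is deprecated and can report misleading grad_input/grad_output for modules with multiple inputs or outputs.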