Can't find a solution! RuntimeError: Trying to backward through the graph a second time

The full error is:

Traceback (most recent call last):
  File "/root/DRIVE/main.py", line 147, in <module>
    loss.backward()
  File "/root/miniconda3/lib/python3.8/site-packages/torch/_tensor.py", line 363, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/root/miniconda3/lib/python3.8/site-packages/torch/autograd/function.py", line 253, in apply
    return user_fn(self, *args)
  File "/root/miniconda3/lib/python3.8/site-packages/spikingjelly/activation_based/surrogate.py", line 1639, in backward
    return leaky_k_relu_backward(grad_output, ctx.saved_tensors[0], ctx.leak, ctx.k)
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

I tried adding retain_graph=True, but then a new error appeared:

Traceback (most recent call last):
  File "/root/DRIVE/main.py", line 147, in <module>
    loss.backward(retain_graph=True)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/_tensor.py", line 363, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [32]] is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

Then I added torch.autograd.set_detect_anomaly(True) and got:

/root/miniconda3/lib/python3.8/site-packages/torch/autograd/__init__.py:173: UserWarning: Error detected in CudnnBatchNormBackward0. Traceback of forward call that caused the error:
  File "/root/DRIVE/main.py", line 143, in <module>
    outputs = s_model(inputs)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/DRIVE/spiking_unet.py", line 115, in forward
    x = self.up4(x, x1)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/DRIVE/spiking_unet.py", line 63, in forward
    x = self.conv(x)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/DRIVE/spiking_unet.py", line 23, in forward
    t = self.c(x)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/spikingjelly/activation_based/layer.py", line 465, in forward
    return functional.seq_to_ann_forward(x, super().forward)
  File "/root/miniconda3/lib/python3.8/site-packages/spikingjelly/activation_based/functional.py", line 686, in seq_to_ann_forward
    y = stateless_module(y)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py", line 168, in forward
    return F.batch_norm(
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/functional.py", line 2421, in batch_norm
    return torch.batch_norm(
 (Triggered internally at  ../torch/csrc/autograd/python_anomaly_mode.cpp:104.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "/root/DRIVE/main.py", line 147, in <module>
    loss.backward(retain_graph=True)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/_tensor.py", line 363, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [32]] is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Here is my training code:

    loss_weight = torch.as_tensor([1.0, 2.0], device=device)
    optimizer = torch.optim.Adam(s_model.parameters(), lr=lr)
    with torch.autograd.set_detect_anomaly(True):
        for i in range(epoch):
            for inputs, labels in train_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = s_model(inputs)
                optimizer.zero_grad()
                # loss = criterion(outputs, labels, loss_weight, dice=False, num_classes=num_classes, ignore_index=255)
                loss = nn.functional.cross_entropy(outputs, labels, ignore_index=255, weight=loss_weight)
                loss.backward(retain_graph=True)
                optimizer.step()
                print(f'epoch {i}: loss = {loss}')

Here is my model:

import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron, functional, surrogate, layer
import torch.nn.functional as F
from typing import Dict


class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super(DoubleConv, self).__init__()
        if mid_channels is None:
            mid_channels = out_channels
        self.c = nn.Sequential(
            layer.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            layer.BatchNorm2d(mid_channels),
            neuron.IFNode(surrogate_function=surrogate.LeakyKReLU(leak=0)),
            layer.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            layer.BatchNorm2d(out_channels),
            neuron.IFNode(surrogate_function=surrogate.LeakyKReLU(leak=0)),
        )

    def forward(self, x):
        t = self.c(x)
        return t



class Down(nn.Module):
    def __init__(self, in_channels, out_channels, step_mode='m'):
        super(Down, self).__init__()
        self.pool = layer.MaxPool2d(2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)
        functional.set_step_mode(self, step_mode=step_mode)

    def forward(self, inputs):
        t = self.pool(inputs)
        tt = self.conv(t)
        return tt


class Up(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=False, step_mode='m'):
        super(Up, self).__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = layer.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        functional.set_step_mode(self, step_mode=step_mode)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        x1 = self.up(x1)
        # [T, N, C, H, W]
        diff_y = x2.size()[3] - x1.size()[3]
        diff_x = x2.size()[4] - x1.size()[4]

        # padding_left, padding_right, padding_top, padding_bottom
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])

        x = torch.cat([x2, x1], dim=2)
        x = self.conv(x)
        return x


class OutConv(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(OutConv, self).__init__()
        self.c = layer.Conv2d(in_channels, num_classes, kernel_size=1)


    def forward(self, x):
        temp = self.c(x)
        return temp


class S_UNet(nn.Module):
    def __init__(self,
                 in_channels: int = 1,
                 num_classes: int = 2,
                 bilinear: bool = False,
                 base_c: int = 64,
                 step_mode: str = 'm',
                 T: int = 4):
        super(S_UNet, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.bilinear = bilinear
        self.T = T

        self.in_conv = DoubleConv(in_channels, base_c)
        self.down1 = Down(base_c, base_c * 2, step_mode)
        self.down2 = Down(base_c * 2, base_c * 4, step_mode)
        self.down3 = Down(base_c * 4, base_c * 8, step_mode)
        factor = 2 if bilinear else 1
        self.down4 = Down(base_c * 8, base_c * 16 // factor, step_mode)
        self.up1 = Up(base_c * 16, base_c * 8 // factor, bilinear, step_mode)
        self.up2 = Up(base_c * 8, base_c * 4 // factor, bilinear, step_mode)
        self.up3 = Up(base_c * 4, base_c * 2 // factor, bilinear, step_mode)
        self.up4 = Up(base_c * 2, base_c, bilinear, step_mode)
        self.out_conv = OutConv(base_c, num_classes)
        functional.set_step_mode(self, step_mode=step_mode)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.repeat(self.T, 1, 1, 1, 1)
        x1 = self.in_conv(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.out_conv(x).mean(0)

        return logits

My model is a UNet built with a spiking neural network, but I think this error has nothing to do with the SNN side of things. This has been bothering me for a long time. I’d appreciate it if you could help me.

Hi Met!

Spiking neural networks maintain some state. If, in your SNN implementation,
that state hangs on to part of pytorch’s computation graph, it could plausibly
cause the error that you’re seeing.
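For example, this toy sketch (just a single IFNode, not your model) should
reproduce that first error, because the node’s membrane potential is still
attached to the first batch’s graph when the second backward runs:

    import torch
    from spikingjelly.activation_based import neuron

    node = neuron.IFNode()  # stateful: keeps its membrane potential between calls

    # first "batch": forward and backward work fine
    node(torch.rand(8, requires_grad=True)).sum().backward()

    # node.v is still part of the first graph, whose intermediate values were just
    # freed, so backward through the second "batch" tries to go through that freed
    # graph again and raises the "backward through the graph a second time" error
    node(torch.rand(8, requires_grad=True)).sum().backward()

The IFNodes inside your DoubleConv blocks keep the same kind of state.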

It appears that this may be what’s happening with the SpikingJelly package
you are using. Quoting from the SpikingJelly documentation:

For convenience, we can also call spikingjelly.activation_based.functional.reset_net to reset all modules in a network.

If the network uses one or more stateful modules, it must be reset after processing one batch of data during training and inference:

from spikingjelly.activation_based import functional
# ...
for x, label in tqdm(train_data_loader):
    # ...
    optimizer.zero_grad()
    y = net(x)
    loss = criterion(y, label)
    loss.backward()
    optimizer.step()

    functional.reset_net(net)
    # Never forget to reset the network!

If we forget to reset, we may get a wrong output during inference or an error during training:

RuntimeError: Trying to backward through the graph a second time (or directly access saved variables after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved variables after calling backward.

Try adding functional.reset_net(s_model) after optimizer.step() in
your training loop.
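
Concretely, a sketch of your loop with that change (and with retain_graph=True
removed, since it should no longer be needed once the state is cleared every
iteration):

    from spikingjelly.activation_based import functional

    for i in range(epoch):
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = s_model(inputs)
            loss = nn.functional.cross_entropy(outputs, labels, ignore_index=255, weight=loss_weight)
            loss.backward()  # plain backward, no retain_graph
            optimizer.step()
            functional.reset_net(s_model)  # clear all stateful modules before the next batch
            print(f'epoch {i}: loss = {loss.item()}')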

Best.

K. Frank


Hi Frank, you are right. After I added ‘functional.reset_net(s_model)’, everything works well. My gratitude is beyond description. You saved my day! :pray: