Residual connection with detach().clone() and requires_grad=True

I do know that residual/skip connections can be implemented by simply doing

out = someOperation(x)
residual = x
out += residual
return out

but I am wondering whether we get the same outcome by doing it in the following way

out = someOperation(x)

residual = x.detach().clone()
residual.requires_grad = True

out += residual
return out

Now, I know you’re asking yourself why I would even go to this trouble if the first version works well, but unfortunately it doesn’t in my case. I believe this is because I am creating a partially invertible architecture, and the error might be caused by the implementation of the library for invertible architectures, but I am not completely sure since I do not fully understand it yet.

If my implementation follows the first case, I get

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

The problem is not caused by calling backward() separately for several losses or anything like that, which is why I assume it may be caused by the invertible architectures library.

However, implementing it as shown in the second case (stumbled upon it here) does train, but I am worried about whether the gradients are kept the way they should be, which is why I’m looking for the answer here. Thanks!

Hi,

The residual.requires_grad = True is not really useful here, as you won’t be able to access the computed grad.
And the detach() will break the graph, which means that no gradient will flow back through your residual connections. Not sure how that would impact the training, but you will definitely lose those gradients.
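
To make the effect of detach() concrete, here is a minimal sketch. It assumes a toy stand-in for someOperation (multiplying by 2) and a hypothetical helper grad_wrt_input, neither of which comes from the code in the question. It only compares the gradient that reaches x with and without the detached residual.

```python
import torch


def grad_wrt_input(detach_residual):
    """Return d(sum(out))/dx for a toy residual block (hypothetical helper)."""
    x = torch.ones(3, requires_grad=True)
    out = 2.0 * x  # toy stand-in for someOperation(x)
    if detach_residual:
        residual = x.detach().clone()  # new leaf tensor, cut off from x
        residual.requires_grad = True
    else:
        residual = x
    out = out + residual
    out.sum().backward()
    return x.grad.clone()


grad_plain = grad_wrt_input(detach_residual=False)
grad_detached = grad_wrt_input(detach_residual=True)
print(grad_plain)     # operation path plus skip path
print(grad_detached)  # operation path only
```

With the plain residual each gradient entry is 3 (2 from the operation, 1 from the skip path); with the detached copy it drops to 2, because the skip path no longer contributes to x.grad.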

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

This can happen for a few reasons. You will need to share some code.
Also, does using out = out + residual solve your issue?
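
For background on why the out-of-place form can matter: an in-place += overwrites a tensor that autograd may still need for the backward pass, whereas out = out + residual allocates a new one. The sketch below assumes exp() as a toy operation that saves its output for backward, and backward_succeeds is a hypothetical helper. Note that the error it triggers is a different one from the error in the question, so treat it only as an illustration of the suggestion.

```python
import torch


def backward_succeeds(in_place):
    """Return True if backward() runs without a RuntimeError (hypothetical helper)."""
    x = torch.ones(3, requires_grad=True)
    out = x.exp()  # exp() keeps its output around for the backward pass
    residual = x
    if in_place:
        out += residual  # overwrites the tensor that exp() saved
    else:
        out = out + residual  # allocates a new tensor, saved output untouched
    try:
        out.sum().backward()
    except RuntimeError as err:
        print(err)
        return False
    return True


in_place_ok = backward_succeeds(in_place=True)
out_of_place_ok = backward_succeeds(in_place=False)
print(in_place_ok, out_of_place_ok)
```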

Thanks for the quick response @albanD

Sure, I’m training a GAN in which I’m using a 3D U-Net-like generator that’s made partially invertible using MemCNN. The piece of code where I have this problem is

class DownTransition(nn.Module):
    def __init__(self, inChans, nConvs):
        super(DownTransition, self).__init__()
        outChans = 2*inChans
        self.down_conv_ab = self.build_down_conv(inChans, outChans)
        self.down_conv_ba = self.build_down_conv(inChans, outChans)
        self.core = nn.Sequential(*[RevBlock(outChans) for _ in range(nConvs)])
        self.relu = nn.PReLU(outChans)

    def build_down_conv(self, inChans, outChans):
        return nn.Sequential(nn.Conv3d(inChans, outChans, kernel_size=2, stride=2),
                             nn.BatchNorm3d(outChans),
                             nn.PReLU(outChans))

    def forward(self, x, inverse=False):
        if inverse:
            down_conv = self.down_conv_ba
            core = reversed(self.core)
        else:
            down_conv = self.down_conv_ab
            core = self.core

        down = down_conv(x)
        out = down  # problematic line(s)
        for block in core:
            out = block(out, inverse=inverse)
        
        out = out + down  # tried as suggested, didn't help
        return self.relu(out)

and the RevBlock class is:

class RevBlock(nn.Module):
    def __init__(self, nchan):
        super(RevBlock, self).__init__()
        
        invertible_module = memcnn.AdditiveCoupling(
            Fm=self.build_conv_block(nchan//2),
            Gm=self.build_conv_block(nchan//2)
        )

        self.rev_block = memcnn.InvertibleModuleWrapper(fn=invertible_module, 
                                                        keep_input=True, 
                                                        keep_input_inverse=True)

    def build_conv_block(self, nchan):
        block = nn.Sequential(nn.Conv3d(nchan, nchan, kernel_size=5, padding=2),
                              nn.BatchNorm3d(nchan),
                              nn.PReLU(nchan))
        return block

    def forward(self, x, inverse=False):
        if inverse:
            return self.rev_block.inverse(x)
        else:
            return self.rev_block(x)

Furthermore, you can see here how the forward and backward passes are implemented. I’m using that framework and I haven’t made any changes to that part of the code.

Finally, the traceback I get is as follows:

Traceback (most recent call last):
  File "train.py", line 35, in <module>
    model.optimize_parameters()
  File "/rwthfs/rz/cluster/home/ft002207/ibRevGAN/models/unpaired_revgan3d_model.py", line 158, in optimize_parameters
    self.backward_G()
  File "/rwthfs/rz/cluster/home/ft002207/ibRevGAN/models/unpaired_revgan3d_model.py", line 150, in backward_G
    self.loss_G.backward()
  File "/home/ft002207/anaconda3/envs/maastro/lib/python3.7/site-packages/torch/tensor.py", line 166, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/ft002207/anaconda3/envs/maastro/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
  File "/rwthfs/rz/cluster/home/ft002207/ibRevGAN/memcnn/models/revop.py", line 38, in backward_hook
    temp_output.backward(gradient=grad_output)
  File "/home/ft002207/anaconda3/envs/maastro/lib/python3.7/site-packages/torch/tensor.py", line 166, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/ft002207/anaconda3/envs/maastro/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
  File "/rwthfs/rz/cluster/home/ft002207/ibRevGAN/memcnn/models/revop.py", line 38, in backward_hook
    temp_output.backward(gradient=grad_output)
  File "/home/ft002207/anaconda3/envs/maastro/lib/python3.7/site-packages/torch/tensor.py", line 166, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/ft002207/anaconda3/envs/maastro/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

So the most likely problem is that you perform some operations in a differentiable manner outside of your training loop, so these operations end up being part of the computational graph for multiple iterations.
Could you give a small (30-40 lines) example that I can use to reproduce this?
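
As an illustration of that failure mode, here is a minimal sketch. It is not the MemCNN case from the question: w, the exp() operation, and the helper run are all made up for the example. The intermediate tensor is either built once before the loop and reused, or rebuilt in every iteration.

```python
import torch

w = torch.ones(3, requires_grad=True)


def run(shared_outside_loop, steps=2):
    """Return True if every backward() call succeeds (hypothetical helper)."""
    w.grad = None
    # Built once, before the loop: its graph node is shared by all iterations.
    shared = w.exp() if shared_outside_loop else None
    try:
        for _ in range(steps):
            # Otherwise the differentiable op is redone in each iteration.
            h = shared if shared_outside_loop else w.exp()
            loss = (h * 2.0).sum()
            loss.backward()  # frees the buffers saved by exp()
    except RuntimeError as err:
        print(err)
        return False
    return True


shared_ok = run(shared_outside_loop=True)
rebuilt_ok = run(shared_outside_loop=False)
print(shared_ok, rebuilt_ok)
```

In the shared case the second backward() reaches the exp() node whose saved buffers were already freed by the first call, which is the situation the error message describes. Rebuilding the tensor inside the loop avoids it.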

Hi @albanD! Thanks for helping. Here’s a minimal example with which you can reproduce the error; it is slightly longer though, at 70 lines.

  • PyTorch: 1.3.1
  • MemCNN: 1.2.1 (latest) - pip install memcnn

Code:

import torch
from torch import nn, optim
import memcnn


class RevBlock(nn.Module):
    def __init__(self, nchan):
        super(RevBlock, self).__init__()

        invertible_module = memcnn.AdditiveCoupling(
            Fm=self.build_conv_block(nchan//2),
            Gm=self.build_conv_block(nchan//2)
        )

        self.rev_block = memcnn.InvertibleModuleWrapper(fn=invertible_module, 
                                                        keep_input=True, 
                                                        keep_input_inverse=True)

    def build_conv_block(self, nchan):
        return nn.Sequential(nn.Conv3d(nchan, nchan, kernel_size=5, padding=2),
                             nn.BatchNorm3d(nchan),
                             nn.PReLU(nchan))
        
    def forward(self, x, inverse=False):
        if inverse:
            return self.rev_block.inverse(x)
        else:
            return self.rev_block(x)


class DownTransition(nn.Module):
    def __init__(self, inChans, nConvs):
        super(DownTransition, self).__init__()
        outChans = 2*inChans
        self.down_conv_ab = self.build_down_conv(inChans, outChans)
        self.down_conv_ba = self.build_down_conv(inChans, outChans)
        self.core = nn.Sequential(*[RevBlock(outChans) for _ in range(nConvs)])
        self.relu = nn.PReLU(outChans)

    def build_down_conv(self, inChans, outChans):
        return nn.Sequential(nn.Conv3d(inChans, outChans, kernel_size=2, stride=2),
                             nn.BatchNorm3d(outChans),
                             nn.PReLU(outChans))

    def forward(self, x, inverse=False):
        if inverse:
            down_conv = self.down_conv_ba
            core = reversed(self.core)
        else:
            down_conv = self.down_conv_ab
            core = self.core

        down = down_conv(x)
        out = down
        for block in core:
            out = block(out, inverse=inverse)
        
        out = out + down
        return self.relu(out)


model = DownTransition(16, 2)
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for i in range(10):
    optimizer.zero_grad()
    data, target = torch.rand((2,16,64,64,64)), torch.rand((2,32,32,32,32))
    out = model.forward(data)
    loss = criterion(out, target)
    loss.backward()
    optimizer.step()

Traceback:

Traceback (most recent call last):
  File "minimal.py", line 71, in <module>
    loss.backward()
  File "/home/bro/anaconda3/envs/maastro/lib/python3.7/site-packages/torch/tensor.py", line 166, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/bro/anaconda3/envs/maastro/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

@albanD sorry for bothering you, I’m wondering if you had time to look into it? Thanks!

Hi,

I did check, but I don’t really have any useful update :confused:
In particular, I have no idea how memcnn works, and the error comes from there, as removing it fixes the issue.
I guess you want to do something similar to what is explained in this post: "Mysterious 'trying to backward through the graph a second time' issue", to try and figure out where the problem comes from :slight_smile:

Hi @albanD, I appreciate it, I did find out however that it works with PyTorch 1.1.0, so I guess I’ll go with that for the moment :smiley:

Well, you might want to be careful and double check that the computed gradients are correct for a simple case.
We usually only introduce these errors at places where it used to do bad things :smiley:
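
One possible way to do such a check on a simple case is torch.autograd.gradcheck, which compares analytical gradients with finite-difference estimates. The sketch below uses tanh as a toy stand-in for the real operation, and both block functions are hypothetical. Whether this is practical for MemCNN-wrapped modules is not something this thread establishes.

```python
import torch
from torch.autograd import gradcheck


def residual_block(x):
    # Toy stand-in for someOperation(x) + residual
    return torch.tanh(x) + x


def detached_residual_block(x):
    # Same block, but the skip path is cut off from autograd
    return torch.tanh(x) + x.detach().clone()


# gradcheck expects double-precision inputs that require grad.
x = torch.randn(4, 3, dtype=torch.double, requires_grad=True)

plain_ok = gradcheck(residual_block, (x,), eps=1e-6, atol=1e-4)
# raise_exception=False makes gradcheck return False instead of raising.
detached_ok = gradcheck(detached_residual_block, (x,), eps=1e-6, atol=1e-4,
                        raise_exception=False)
print(plain_ok, detached_ok)
```

The plain block passes, while the detached variant fails the check: the finite-difference estimate still sees the skip path, but the analytical gradient does not.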
