Thanks for the quick response @albanD
Sure, I’m training a GAN in which I’m using a 3D U-Net-like generator that’s made partially invertible using MemCNN. The piece of code where I have this problem is
class DownTransition(nn.Module):
    """Downsampling stage of the partially-invertible 3D U-Net generator.

    Halves the spatial resolution and doubles the channel count with a
    strided Conv3d, then runs a stack of invertible `RevBlock`s. Two
    separate down-convolutions (`ab` and `ba`) are kept so the forward
    and inverse directions of the network use independent weights.
    """

    def __init__(self, inChans, nConvs):
        # inChans: number of input channels; nConvs: number of RevBlocks.
        super(DownTransition, self).__init__()
        outChans = 2*inChans  # channel doubling on every down-transition
        self.down_conv_ab = self.build_down_conv(inChans, outChans)
        self.down_conv_ba = self.build_down_conv(inChans, outChans)
        # Invertible core; stored in nn.Sequential so it is registered and
        # can be iterated in reverse order for the inverse pass.
        self.core = nn.Sequential(*[RevBlock(outChans) for _ in range(nConvs)])
        self.relu = nn.PReLU(outChans)

    def build_down_conv(self, inChans, outChans):
        # Strided 2x2x2 conv (stride 2) -> BN -> PReLU: downsamples by 2.
        return nn.Sequential(nn.Conv3d(inChans, outChans, kernel_size=2, stride=2),
                             nn.BatchNorm3d(outChans),
                             nn.PReLU(outChans))

    def forward(self, x, inverse=False):
        """Apply the down-transition; `inverse=True` runs the ba weights
        and iterates the invertible core in reversed block order."""
        if inverse:
            down_conv = self.down_conv_ba
            core = reversed(self.core)  # nn.Sequential supports reversed()
        else:
            down_conv = self.down_conv_ab
            core = self.core
        down = down_conv(x)
        out = down  # problematic line(s) per the report below
        for block in core:
            out = block(out, inverse=inverse)
        # NOTE(review): residual add re-uses `down`, whose graph memcnn's
        # InvertibleModuleWrapper may already have backpropagated through and
        # freed — presumably the source of the "backward through the graph a
        # second time" RuntimeError in the traceback; confirm against memcnn's
        # keep_input semantics.
        out = out + down  # tried as suggested, didn't help
        return self.relu(out)
and the RevBlock class is:
class RevBlock(nn.Module):
    """Invertible residual block built on memcnn's additive coupling.

    Wraps two small conv sub-networks (Fm, Gm), each operating on half of
    the channels, in an `InvertibleModuleWrapper` so the block can be run
    both forwards and in inverse while keeping its inputs in memory.
    """

    def __init__(self, nchan):
        super(RevBlock, self).__init__()
        half = nchan // 2  # additive coupling splits the channels in two
        coupling = memcnn.AdditiveCoupling(
            Fm=self.build_conv_block(half),
            Gm=self.build_conv_block(half),
        )
        # keep_input / keep_input_inverse retain activations instead of
        # freeing them after the invertible computation.
        self.rev_block = memcnn.InvertibleModuleWrapper(
            fn=coupling,
            keep_input=True,
            keep_input_inverse=True,
        )

    def build_conv_block(self, nchan):
        # 5x5x5 conv (padding 2, shape-preserving) -> BN -> PReLU.
        return nn.Sequential(
            nn.Conv3d(nchan, nchan, kernel_size=5, padding=2),
            nn.BatchNorm3d(nchan),
            nn.PReLU(nchan),
        )

    def forward(self, x, inverse=False):
        """Run the wrapped coupling; `inverse=True` applies its inverse."""
        if inverse:
            return self.rev_block.inverse(x)
        return self.rev_block(x)
Furthermore, you can see here how the forward and backward passes are implemented. I’m using that framework and I haven’t done any changes in that part of the code.
Finally, the traceback I get is as follows:
Traceback (most recent call last):
File "train.py", line 35, in <module>
model.optimize_parameters()
File "/rwthfs/rz/cluster/home/ft002207/ibRevGAN/models/unpaired_revgan3d_model.py", line 158, in optimize_parameters
self.backward_G()
File "/rwthfs/rz/cluster/home/ft002207/ibRevGAN/models/unpaired_revgan3d_model.py", line 150, in backward_G
self.loss_G.backward()
File "/home/ft002207/anaconda3/envs/maastro/lib/python3.7/site-packages/torch/tensor.py", line 166, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/home/ft002207/anaconda3/envs/maastro/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
allow_unreachable=True) # allow_unreachable flag
File "/rwthfs/rz/cluster/home/ft002207/ibRevGAN/memcnn/models/revop.py", line 38, in backward_hook
temp_output.backward(gradient=grad_output)
File "/home/ft002207/anaconda3/envs/maastro/lib/python3.7/site-packages/torch/tensor.py", line 166, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/home/ft002207/anaconda3/envs/maastro/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
allow_unreachable=True) # allow_unreachable flag
File "/rwthfs/rz/cluster/home/ft002207/ibRevGAN/memcnn/models/revop.py", line 38, in backward_hook
temp_output.backward(gradient=grad_output)
File "/home/ft002207/anaconda3/envs/maastro/lib/python3.7/site-packages/torch/tensor.py", line 166, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/home/ft002207/anaconda3/envs/maastro/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.