Segmentation fault when calling loss.backward()

I have a complex model that contains the class below. The output of this class is (2, 1280, 208, 208), which is sent to a Conv2d that produces (2, 10, 208, 208); I am trying to do pixel-wise segmentation. When I call loss.backward() I get a segmentation fault. I can tell the crash happens in the cat function, because when I return just layer1 (or any other single layer) from the CrossAttention1 class, I do not get a segmentation fault.

import torch
import torch.nn as nn


class CrossAttention1(nn.Module):
    def __init__(self, dim):  # dim is currently unused
        super(CrossAttention1, self).__init__()
        # Upsample the coarser feature maps so all five levels share
        # the spatial size of layer1 (208x208 in my case).
        self.up5 = nn.Upsample(scale_factor=8, mode='bilinear')
        self.up4 = nn.Upsample(scale_factor=8, mode='bilinear')
        self.up3 = nn.Upsample(scale_factor=4, mode='bilinear')
        self.up2 = nn.Upsample(scale_factor=2, mode='bilinear')

    def forward(self, layer1, layer2, layer3, layer4, layer5):
        layer2 = self.up2(layer2)
        layer3 = self.up3(layer3)
        layer4 = self.up4(layer4)
        layer5 = self.up5(layer5)
        # Concatenate along the channel dimension: 5 x 256 = 1280 channels.
        x = torch.cat([layer1, layer2, layer3, layer4, layer5], 1)
        return x
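
For context, the concatenated output is fed into a small segmentation head roughly like this (just a sketch of my setup; the 1x1 kernel and the name seg_head are only illustrative):

# Hypothetical head: maps the 1280 concatenated channels to 10 class
# logits per pixel, keeping the 208x208 resolution.
seg_head = nn.Conv2d(1280, 10, kernel_size=1)
# fused = cross_attention(layer1, layer2, layer3, layer4, layer5)  # (2, 1280, 208, 208)
# logits = seg_head(fused)                                         # (2, 10, 208, 208)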

Do you get any stack trace or just the segmentation fault?
Could you post the shapes of all layerX inputs?

There is no stack trace; I just got Segmentation fault (core dumped).
layer1 input shape: (2, 256, 208, 208)
layer2 input shape: (2, 256, 104, 104)
layer3 input shape: (2, 256, 52, 52)
layer4 input shape: (2, 256, 26, 26)
layer5 input shape: (2, 256, 26, 26)

When I try batch size 1 instead of 2 it works fine; the problem only appears with batch size > 1. I also tried reducing the input size to make sure I am not running out of RAM, but the problem stays the same.

Your module works fine with these input shapes on my machine.
Could you post an executable code snippet reproducing this error?
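
For reference, this is roughly the check I ran — a minimal sketch with random tensors of the shapes you posted and a dummy sum as the loss, not your real training setup:

import torch

model = CrossAttention1(dim=256)
layer1 = torch.randn(2, 256, 208, 208, requires_grad=True)
layer2 = torch.randn(2, 256, 104, 104, requires_grad=True)
layer3 = torch.randn(2, 256, 52, 52, requires_grad=True)
layer4 = torch.randn(2, 256, 26, 26, requires_grad=True)
layer5 = torch.randn(2, 256, 26, 26, requires_grad=True)

out = model(layer1, layer2, layer3, layer4, layer5)
print(out.shape)      # torch.Size([2, 1280, 208, 208])
out.sum().backward()  # completes without a crash on a single device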

I have created a gist https://gist.github.com/AshStuff/87cf8051e48da0a5f9d85f74a5d15c71 that reproduces the error. When I remove the checkpoint the code works fine, and it also works fine when I remove model = torch.nn.DataParallel(model). I have no idea why.
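
The pattern that crashes is roughly the following — a minimal sketch of how I combine checkpoint and DataParallel, not the exact gist code; the wrapper module, the Conv2d head, and the dummy loss are just placeholders:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Wrapper(nn.Module):
    def __init__(self):
        super(Wrapper, self).__init__()
        self.cross = CrossAttention1(dim=256)
        self.head = nn.Conv2d(1280, 10, kernel_size=1)

    def forward(self, l1, l2, l3, l4, l5):
        # Checkpoint the concatenation block to save memory.
        x = checkpoint(self.cross, l1, l2, l3, l4, l5)
        return self.head(x)


model = nn.DataParallel(Wrapper().cuda())
sizes = [208, 104, 52, 26, 26]
inputs = [torch.randn(2, 256, s, s, device='cuda', requires_grad=True) for s in sizes]
out = model(*inputs)
out.sum().backward()  # segfaults; fine if either checkpoint or DataParallel is removed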

Thanks a lot @Ashwin_Raju. I can reproduce the issue on PyTorch 1.0.0 and on PyTorch master.

And indeed, it's the same error as https://github.com/pytorch/pytorch/issues/11732 (I checked the stack trace).

I'm going to file it as a blocker for 1.0.1, which is due in ~15 days, and fix the issue.