I have a complex model, and one part of it is this class. The output of this class has shape (2, 1280, 208, 208); it is passed to a conv2d layer that produces (2, 10, 208, 208), and I am trying to do pixel-wise segmentation. When I call loss.backward() I get a segmentation fault. The fault appears to come from the cat function, because when I return just layer1, or any other single layer, from the CrossAttention1 class, I do not get the segmentation fault.
When I tried batch size 1 instead of 2, it worked fine; the problem only starts with batch size > 1. I also tried reducing the input size to check whether I was simply running out of RAM, but the problem is still the same.
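For a sense of scale, here is a rough estimate of the tensor sizes quoted above. It is only a sketch: it assumes float32 (4 bytes per element), which the post does not state, and `tensor_bytes` is a hypothetical helper, not something from the model.

```python
from math import prod


def tensor_bytes(shape, bytes_per_element=4):
    """Size in bytes of a dense tensor; 4 bytes per element assumes float32."""
    return prod(shape) * bytes_per_element


cat_output = (2, 1280, 208, 208)  # shape of the concatenated output, from the post
logits = (2, 10, 208, 208)        # shape after the conv2d, from the post

for name, shape in [("cat output", cat_output), ("conv2d output", logits)]:
    print(f"{name}: {tensor_bytes(shape) / 2**20:.1f} MiB")
# cat output: 422.5 MiB
# conv2d output: 3.3 MiB
```

These numbers cover only the forward activations of those two tensors. The backward pass needs a gradient of the same size for each, and the rest of the model adds its own activations on top.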
I have created a gist, https://gist.github.com/AshStuff/87cf8051e48da0a5f9d85f74a5d15c71, which reproduces the error. When I remove the checkpoint, the code works fine, and when I instead remove model = torch.nn.DataParallel(model), it also works fine. I have no idea why.
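To make the combinations I tried explicit, here is a minimal sketch of the toggles (checkpoint on/off, DataParallel on/off, batch size 1 or 2). This is not the code from the gist, and I am not claiming it reproduces the crash: `CatBranches` and `run_once` are hypothetical names, the shapes are deliberately tiny, and `use_reentrant=False` assumes a recent PyTorch release.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CatBranches(nn.Module):
    """Hypothetical stand-in for CrossAttention1: concatenates branch outputs."""

    def __init__(self, in_ch=3, branch_ch=8, n_branches=4, use_checkpoint=True):
        super().__init__()
        self.branches = nn.ModuleList(
            [nn.Conv2d(in_ch, branch_ch, kernel_size=3, padding=1)
             for _ in range(n_branches)]
        )
        self.use_checkpoint = use_checkpoint

    def _cat(self, x):
        # Concatenate along the channel dimension, as in the original class.
        return torch.cat([branch(x) for branch in self.branches], dim=1)

    def forward(self, x):
        if self.use_checkpoint:
            # Assumption: use_reentrant is available (recent PyTorch versions).
            return checkpoint(self._cat, x, use_reentrant=False)
        return self._cat(x)


def run_once(batch_size, use_checkpoint, use_data_parallel, n_classes=10):
    """Run one forward/backward pass and return the loss value."""
    model = nn.Sequential(
        CatBranches(use_checkpoint=use_checkpoint),  # 4 branches * 8 channels = 32
        nn.Conv2d(32, n_classes, kernel_size=1),
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    if use_data_parallel:
        model = nn.DataParallel(model)
    x = torch.randn(batch_size, 3, 16, 16, device=device)
    target = torch.randint(0, n_classes, (batch_size, 16, 16), device=device)
    loss = nn.functional.cross_entropy(model(x), target)
    loss.backward()
    return loss.item()


if __name__ == "__main__":
    for batch_size in (1, 2):
        for use_checkpoint in (False, True):
            for use_data_parallel in (False, True):
                loss = run_once(batch_size, use_checkpoint, use_data_parallel)
                print(batch_size, use_checkpoint, use_data_parallel, loss)
```

Running each combination in a separate process would show which one crashes, since a segmentation fault kills the whole interpreter rather than raising a Python exception.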