Torch allocating 103GB

I was trying to make a UNet-like architecture using PyTorch. I got a Runtime Error saying that there is not enough memory as I’m trying to allocate 103GB.

I am using the CPU only. I have shrunk my input to [2,2,2] with a batch_size of 1, but the error message does not change.
Can anyone help me figure out why this is happening and how to fix it?

Hi!
Could you please post the code of your model first?
It may help to figure out the cause of the problem.

class _EncoderBlock(nn.Module):
    """Down-sampling stage: 2x max-pool, then two Conv3d -> BatchNorm3d -> ReLU units.

    Both convolutions use ``kernel_size=3`` with no padding, so each one trims
    two voxels from every spatial dimension of its input.

    Args:
        in_channels: Number of channels of the incoming tensor.
        out_channels: Number of channels produced by both convolutions.
        dropout: When True, ``nn.Dropout`` is applied to the block output.
    """

    def __init__(self, in_channels, out_channels, dropout=False):
        super(_EncoderBlock, self).__init__()
        # Remember the flag so forward() can consult it; it used to be read
        # there as a bare (undefined) name.
        self.use_dropout = dropout

        # NOTE: none of these assignments may end with a trailing comma. A
        # trailing comma stores a 1-tuple instead of the layer, so the layer is
        # neither registered as a submodule nor callable.
        self.e_mxpl1 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.e_conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3)
        self.e_bn1 = nn.BatchNorm3d(out_channels)
        self.e_relu1 = nn.ReLU(inplace=True)
        self.e_conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3)
        self.e_bn2 = nn.BatchNorm3d(out_channels)
        self.e_relu2 = nn.ReLU(inplace=True)
        self.e_drp = nn.Dropout()

    def forward(self, x):
        """Pool ``x``, run it through both conv units and optionally apply dropout."""
        x = self.e_mxpl1(x)
        x = self.e_relu1(self.e_bn1(self.e_conv1(x)))
        x = self.e_relu2(self.e_bn2(self.e_conv2(x)))
        if self.use_dropout:
            x = self.e_drp(x)
        return x

class _DecoderBlock(nn.Module):
    """Up-sampling stage: two Conv3d -> BatchNorm3d -> ReLU units, then a 2x transposed conv.

    Both convolutions use ``kernel_size=3`` with no padding, so each one trims
    two voxels from every spatial dimension. The final ``ConvTranspose3d``
    (kernel 2, stride 2) doubles the spatial size.

    Args:
        in_channels: Number of channels of the incoming tensor.
        middle_channels: Number of channels produced by both convolutions.
        out_channels: Number of channels produced by the transposed convolution.
    """

    def __init__(self, in_channels, middle_channels, out_channels):
        super(_DecoderBlock, self).__init__()
        # NOTE: none of these assignments may end with a trailing comma. A
        # trailing comma stores a 1-tuple instead of the layer, so the layer is
        # neither registered as a submodule nor callable.
        self.d_conv1 = nn.Conv3d(in_channels, middle_channels, kernel_size=3)
        self.d_bn1 = nn.BatchNorm3d(middle_channels)
        self.d_relu1 = nn.ReLU(inplace=True)
        self.d_conv2 = nn.Conv3d(middle_channels, middle_channels, kernel_size=3)
        self.d_bn2 = nn.BatchNorm3d(middle_channels)
        self.d_relu2 = nn.ReLU(inplace=True)
        self.d_convT = nn.ConvTranspose3d(middle_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, x):
        """Run ``x`` through both conv units and up-sample the result by 2."""
        x = self.d_relu1(self.d_bn1(self.d_conv1(x)))
        x = self.d_relu2(self.d_bn2(self.d_conv2(x)))
        return self.d_convT(x)

class WNet3D(nn.Module):
    """Two chained 3-D U-Net-style encoder/decoder passes.

    The first pass (``module_1`` .. ``module_9``, ``mid``, ``sm``) maps the
    input to a ``num_classes``-channel volume that is resized back to the
    input's spatial size. The second pass (``module_10`` .. ``module_18``,
    ``final``) consumes that volume and produces the second output.

    The first layer is ``Conv3d(1, 64, ...)``, so the input must have a single
    channel. Every 3x3x3 convolution is unpadded, so feature maps shrink at each
    stage; skip connections are resized to the decoder's size before being
    concatenated. The first stage keeps 64 channels at almost the full input
    resolution, so activation memory grows with the cube of the input edge
    length.

    Args:
        num_classes: Number of channels of both returned volumes.
    """

    def __init__(self, num_classes=1):
        super(WNet3D, self).__init__()

        # ---- first encoder/decoder pass ----
        self.module_1 = self._double_conv(1, 64)
        self.module_2 = _EncoderBlock(64, 128)
        self.module_3 = _EncoderBlock(128, 256)
        self.module_4 = _EncoderBlock(256, 512, dropout=True)
        self.mxpl_enc = nn.MaxPool3d(kernel_size=2, stride=2)
        self.module_5 = _DecoderBlock(512, 1024, 512)
        self.module_6 = _DecoderBlock(1024, 512, 256)
        self.module_7 = _DecoderBlock(512, 256, 128)
        self.module_8 = _DecoderBlock(256, 128, 64)
        self.module_9 = self._double_conv(128, 64)
        self.mid = nn.Conv3d(64, num_classes, kernel_size=1)
        # NOTE(review): for an (N, C, D, H, W) tensor dim=4 is the last spatial
        # axis, not the class axis (dim=1). Kept as in the original; confirm
        # which normalisation is intended.
        self.sm = nn.Softmax(dim=4)

        # ---- second encoder/decoder pass ----
        # Its input is the num_classes-channel output of the first pass, so the
        # input width follows num_classes (identical to the previous
        # hard-coded 1 for the default).
        self.module_10 = _EncoderBlock(num_classes, 64)
        self.module_11 = _EncoderBlock(64, 128)
        self.module_12 = _EncoderBlock(128, 256)
        self.module_13 = _EncoderBlock(256, 512, dropout=True)
        self.mxpl_dec = nn.MaxPool3d(kernel_size=2, stride=2)
        self.module_14 = _DecoderBlock(512, 1024, 512)
        self.module_15 = _DecoderBlock(1024, 512, 256)
        self.module_16 = _DecoderBlock(512, 256, 128)
        self.module_17 = _DecoderBlock(256, 128, 64)
        self.module_18 = self._double_conv(128, 64)
        self.final = nn.Conv3d(64, num_classes, kernel_size=1)

    @staticmethod
    def _double_conv(in_channels, out_channels):
        """Build two unpadded Conv3d -> BatchNorm3d -> ReLU units as one Sequential."""
        return nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _resize_like(src, ref):
        """Resize ``src`` to the spatial size of ``ref``.

        Uses trilinear interpolation because the tensors are 5-D volumes; the
        bilinear mode only accepts 4-D input.
        """
        return F.interpolate(src, size=ref.size()[2:], mode='trilinear', align_corners=False)

    @classmethod
    def _merge(cls, decoded, skip):
        """Concatenate ``decoded`` with ``skip`` (resized to match) along the channel axis."""
        return torch.cat([decoded, cls._resize_like(skip, decoded)], 1)

    def forward(self, x):
        """Run both passes.

        Returns:
            A tuple ``(middle, final)``; both are resized to the spatial size
            of ``x``.
        """
        # ---- first pass ----
        enc1 = self.module_1(x)
        enc2 = self.module_2(enc1)
        enc3 = self.module_3(enc2)
        enc4 = self.mxpl_enc(self.module_4(enc3))
        center = self.module_5(enc4)
        dec4 = self.module_6(self._merge(center, enc4))
        dec3 = self.module_7(self._merge(dec4, enc3))
        dec2 = self.module_8(self._merge(dec3, enc2))
        dec1 = self.module_9(self._merge(dec2, enc1))
        middle = self._resize_like(self.sm(self.mid(dec1)), x)

        # ---- second pass ----
        enc5 = self.module_10(middle)
        enc6 = self.module_11(enc5)
        enc7 = self.module_12(enc6)
        enc8 = self.mxpl_dec(self.module_13(enc7))
        center = self.module_14(enc8)
        dec8 = self.module_15(self._merge(center, enc8))
        dec7 = self.module_16(self._merge(dec8, enc7))
        dec6 = self.module_17(self._merge(dec7, enc6))
        dec5 = self.module_18(self._merge(dec6, enc5))
        final = self._resize_like(self.final(dec5), x)

        return middle, final

Thanks!
And what is the size of your input tensor?

I scaled it to (1,256,256,256)