Memory error with UNET implementation on both CPU and GPU

Hi, I am a newbie to PyTorch. I made the following PyTorch implementation of UNET; it seems correct, but it always runs out of memory, on both CUDA and CPU.

import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.duo = nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                    nn.ReLU(inplace=True),
                    nn.BatchNorm2d(out_channels),
                    nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                    nn.ReLU(inplace=True),
                    nn.BatchNorm2d(out_channels),
                )
        
    def forward(self, x):
        return self.duo(x)

class UNET1(nn.Module):
    def __init__(self, in_channel=3, out_channel=1, features=[64, 128, 256, 512]):
        super().__init__()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()

        self.copy = []
        self.compilation = []
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.features = features
         
        in_feature = in_channel
        for feature in features:
            model = DoubleConv(in_feature, feature)
            in_feature = feature
            self.downs.append(model)
        self.baseconv = DoubleConv(features[-1], features[-1]*2)

        for feature in reversed(features):
          conv = nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2, bias=False)
          self.ups.append(conv)
          self.ups.append(DoubleConv(feature*2, feature))
        self.last_conv = nn.Conv2d(features[0], out_channel, kernel_size=1, stride=1)
          
    def encoder(self, x):
        for down in self.downs:            
            x = down(x)
            self.copy.append(x)
            x = self.pool(x)
            self.compilation.append(x)
            
        self.copy = self.copy[::-1]
        return self.compilation

    def base(self):
      last = self.compilation[-1]
      basic = self.baseconv(last)
      self.compilation.append(basic)
      return basic

    def decoder(self):
        layer = self.compilation[-1]
        for enum in range(0, len(self.ups), 2):
          up_conv = self.ups[enum](layer)
          concatena = TF.center_crop(self.copy[enum//2], up_conv.shape[2:])
          layer = torch.cat((concatena, up_conv), dim=1)
          layer = self.ups[enum+1](layer)
        return self.last_conv(layer)

    def forward(self, x):
      self.encoder(x)
      self.base()
      res  = self.decoder()
      return res

Now, here is code I copied online. According to torchinfo.summary, the two models have similar memory usage, but when actually running the code, this one does okay in terms of memory, while my implementation always crashes with a memory error after an epoch.

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)

class UNET2(nn.Module):
    def __init__(
            self, in_channels=3, out_channels=1, features=[64, 128, 256, 512],
    ):
        super(UNET2, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of UNET
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up part of UNET
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(
                    feature*2, feature, kernel_size=2, stride=2,
                )
            )
            self.ups.append(DoubleConv(feature*2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx//2]

            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx+1](concat_skip)

        return self.final_conv(x)

def test():
    x = torch.randn((3, 1, 161, 161))
    model = UNET2(in_channels=1, out_channels=1)
    preds = model(x)
    assert preds.shape == x.shape

What am I missing, please?

What kind of “memory error” do you get? Could you post the error message with the stack trace, please?

RuntimeError: CUDA out of memory. Tried to allocate 76.00 MiB (GPU 0; 11.17 GiB total capacity; 10.41 GiB already allocated; 64.81 MiB free; 10.58 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

My implementation gives this; the other runs without error.

Thank you.

If you check the number of parameters via sum([p.nelement() for p in model1.parameters()]) you’ll see that a small difference is already visible: 31036673 vs. 31037633, which points to a difference in the module initialization. Based on this small difference I would guess that a bias parameter is skipped in one model implementation; that alone won’t explain the difference in memory usage, though.
To further narrow it down, check the shapes of all intermediate activation tensors (either in the forward methods directly or via forward hooks) and compare them.
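
Something like this minimal sketch would work (shapes1 and shapes2 are just placeholder names):

# assuming model1 = UNET1() and model2 = UNET2() from the posts above
def make_hook(name, log):
    # forward hook that records the output shape of each module
    def hook(module, inp, out):
        if isinstance(out, torch.Tensor):
            log.append((name, tuple(out.shape)))
    return hook

shapes1, shapes2 = [], []
for name, module in model1.named_modules():
    module.register_forward_hook(make_hook(name, shapes1))
for name, module in model2.named_modules():
    module.register_forward_hook(make_hook(name, shapes2))

x = torch.randn(1, 3, 160, 160)
model1(x)
model2(x)
# now compare shapes1 and shapes2 entry by entry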

I am not using biases, since I am using BatchNorm after the convolutional layers; therefore, the working model has slightly more parameters than mine. The difference of 31037633 - 31036673 = 960 parameters matches exactly the bias vectors of the four ConvTranspose2d layers (64 + 128 + 256 + 512), which I create with bias=False.

Could it be because of the additional self.copy and self.compilation lists in the UNET1 class? That is the only visible difference.

Thanks again.

Yes, appending tensors to a list would increase the memory usage and, as you’ve pointed out, the usage of self.copy and self.compilation could cause the majority of the memory difference. Since both lists are module attributes, they are never cleared between iterations, so the stored activations (together with their computation graphs) accumulate with every forward pass until you run out of memory.

Thanks for the help! All I needed to do was to make copy and compilation local to the forward pass instead of class attributes, roughly as in the sketch below.
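
A minimal version of the fixed forward pass (once the lists are local, compilation turns out to be unnecessary, since only the running x is ever used):

def forward(self, x):
    copy = []  # skip connections, local to this forward pass
    for down in self.downs:
        x = down(x)
        copy.append(x)
        x = self.pool(x)
    copy = copy[::-1]

    x = self.baseconv(x)

    for enum in range(0, len(self.ups), 2):
        x = self.ups[enum](x)  # transposed conv upsamples
        skip = TF.center_crop(copy[enum // 2], x.shape[2:])
        x = torch.cat((skip, x), dim=1)
        x = self.ups[enum + 1](x)  # DoubleConv after concatenation
    return self.last_conv(x)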

Thank you!

Excuse me, I had the same problem, as I save the results in a list … how can I solve this problem?

Saving results in a list without detaching them will not only store the actual tensors but also the entire computation graph, as long as the tensors are still attached to it. How to “solve” it depends on your use case, since the mentioned usage is not wrong by itself if you explicitly want to store the computation graphs. If not, detach the tensors before appending them to the list.
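
For example, in a training loop (loader, model, criterion, and optimizer are placeholders for your own objects):

losses = []
for data, target in loader:
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # losses.append(loss)        # keeps the entire graph alive -> memory grows
    losses.append(loss.detach()) # stores the value only; the graph can be freed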

Thanks a lot for this clarification!