GPU memory usage increases even when overwriting variables

I’m currently training a ResUNet with 3 encoding blocks, 1 bottleneck, 4 decoding blocks, and an output layer on an RTX 3090, and I’m having trouble fitting the model, training, and validation within the available 24 GB of memory.

The model itself isn’t that large, and both batch_input.shape and batch_target.shape are (4, 3, 604, 513), so the batches don’t take up much space in memory either.
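
For reference, a quick sizing check (a sketch, assuming float32 tensors) shows each batch is only around 14 MB, which is negligible on a 24 GB card:

    import torch

    # One float32 batch of shape (4, 3, 604, 513): 4 bytes per element
    batch = torch.zeros(4, 3, 604, 513)
    size_mb = batch.element_size() * batch.nelement() / 1024**2
    print(f"{size_mb:.1f} MB per batch")  # roughly 14 MB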

Within my forward pass the GPU memory usage seems to keep increasing even though I’m overwriting the x variable used for the input and output of each layer, as shown below:

    def forward(self, x):
        skip_connections = []
        s = x.clone()  # copy of the input for the residual skip path
        # Initial conv block
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.conv2(x)
        s = self.conv_skip(s)
        x = x + s
        skip_connections.append(x)
        # Encoding blocks
        for encode in self.encodes:
            x = encode(x)
            skip_connections.append(x)
        # Bottleneck
        x = self.bottleneck(x)
        # Decoder: consume the skip connections in reverse order
        skip_connections = skip_connections[::-1]
        for idx, decode in enumerate(self.decodes):
            x = decode(x, skip_connections[idx])
        # Output layer
        x = self.output(x)
        return x
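
A rough sketch of how the per-layer growth can be observed (assuming the model, here called model, and batch_input are already on a CUDA device) is to print the allocated memory from a forward hook on each top-level block:

    import torch

    def log_memory(name):
        def hook(module, inputs, output):
            # memory_allocated() only counts tensors that are still alive,
            # including activations autograd keeps around for backward
            mb = torch.cuda.memory_allocated() / 1024**2
            print(f"after {name}: {mb:.0f} MB allocated")
        return hook

    for name, module in model.named_children():
        module.register_forward_hook(log_memory(name))

    out = model(batch_input.cuda())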

I understand that certain “versions” of x will be larger than others, but even when I return to my original output size the memory usage remains large. Is there something to do with how PyTorch caches memory on the GPU that I’m missing?
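
A minimal sketch, assuming a CUDA device, of the two numbers worth separating when checking this:

    import torch

    # Memory used by live tensors (parameters, activations, gradients)
    print(torch.cuda.memory_allocated() / 1024**2, "MB allocated")
    # Memory the caching allocator has reserved from the driver; this stays
    # high even after tensors are freed, until empty_cache() is called
    print(torch.cuda.memory_reserved() / 1024**2, "MB reserved")

    # Returns unused cached blocks to the driver (does not free live tensors)
    torch.cuda.empty_cache()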

During training, Autograd stores the intermediate forward activations since they are needed for the gradient computation. If you don’t want to train the model or compute gradients, you can disable this behavior by wrapping the forward pass in with torch.no_grad():, which avoids storing these intermediates.
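
For example, a validation or inference pass could look roughly like this (a sketch; model stands for your ResUNet instance):

    import torch

    model.eval()  # use running batch-norm statistics, disable dropout
    with torch.no_grad():
        # no activations are kept for backward, so the forward pass
        # needs far less memory than during training
        prediction = model(batch_input.cuda())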