Where did I leak my GPU memory? 18 hours of debugging, to no avail

I have trouble fitting a batch size of 32 for the pretrained model swin_large_patch4_window7_224 from the well-known timm library. With mixed precision on, other people’s pipelines (I will link their code) can easily fit a batch size of 32 without a CUDA OOM, but mine cannot. To debug this, I first stepped through my code line by line and tried various changes, such as constructing the optimizer and scheduler outside my class, but nothing worked.

After browsing the PyTorch forums, I found a way to check, line by line, where my OOM occurred, using the function below.

import torch

def print_gpu_usage():
    # divide by 1e+9 to convert bytes to GB
    print(torch.cuda.memory_allocated() / 1e+9)      # memory currently held by tensors
    print(torch.cuda.max_memory_allocated() / 1e+9)  # peak memory since the start of the program
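Since torch.cuda.max_memory_allocated() reports the all-time peak, a small variant (my addition, not from the original thread) resets the peak counter after each call, so every checkpoint reports the peak of its own stage only:

def print_gpu_usage_and_reset(tag: str = "") -> None:
    # current and peak allocations in GB; reset the peak counter so the
    # next call measures only the stage that follows
    allocated = torch.cuda.memory_allocated() / 1e+9
    peak = torch.cuda.max_memory_allocated() / 1e+9
    print(f"{tag} allocated: {allocated:.3f} GB | peak: {peak:.3f} GB")
    torch.cuda.reset_peak_memory_stats()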

I then added this function call at likely suspects in both notebooks (mine and the working one by others). To make the two notebooks as similar as possible, I used the exact same model class and dataset class, so the tensors are handled the same way and pass through the same model. The difference soon became apparent, but I have no idea why it happens. I ran both notebooks up to the model(images) call, where I believe the issue occurs. The numbers below are the GPU memory (in GB) at each stage; somehow, at the exact same line, feature_logits = self.extract_features(image), the working version uses only about 10 GB while the non-working one uses almost 13 GB.

The traceback can be found in both notebooks, and both run in under a minute if someone is kind enough to look at them. I fully acknowledge that I don’t understand GPU memory optimization at a deep level.

Non-working Version
Working Version

# WORKING VERSION
1st step: Loading Tensors In Forward Function
0.802948096
0.802948096
Sending Tensors to the function Extract Features --- 
0.802948096
0.802948096
image dtype : torch.float32
image shape : torch.Size([32, 3, 224, 224])
After Extract Features ---
feature_logits dtype : torch.float32
feature_logits shape : torch.Size([32, 1536])
10.382237184
10.454902784
After Extracting Features and Before calling Head function ---
10.382237184
10.454902784
1st step: Loading Tensors In Forward Function
3.159162368
10.454902784
Sending Tensors to the function Extract Features --- 
3.159162368
10.454902784
image dtype : torch.float32
image shape : torch.Size([32, 3, 224, 224])


# MY OOM VERSION
1st step: Loading Tensors In Forward Function
0.802948096
0.802948096
Sending Tensors to the function Extract Features --- 
0.802948096
0.802948096
image dtype : torch.float32
image shape : torch.Size([32, 3, 224, 224])
After Extract Features ---
feature_logits dtype : torch.float32
feature_logits shape : torch.Size([32, 1536])
12.599565312
12.62871552
After Extracting Features and Before calling Head function ---
12.599565312
12.62871552
1st step: Loading Tensors In Forward Function
12.599565824
12.62871552
Sending Tensors to the function Extract Features --- 
12.599565824
12.62871552
image dtype : torch.float32
image shape : torch.Size([32, 3, 224, 224])

The model class is made identical in both notebooks:

from typing import Callable, Dict

import timm
import torch


class PetNet(torch.nn.Module):
    def __init__(
        self,
        model_name: str = "swin_large_patch4_window7_224",
        out_features: int = 1,
        in_channels: int = 3,
        pretrained: bool = True,
    ):
        """Construct a new model.
        Args:
            model_name ([type], str): The name of the model to use. Defaults to MODEL_PARAMS.model_name.
            out_features ([type], int): The number of output features, this is usually the number of classes, but if you use sigmoid, then the output is 1. Defaults to MODEL_PARAMS.output_dimension.
            in_channels ([type], int): The number of input channels; RGB = 3, Grayscale = 1. Defaults to MODEL_PARAMS.input_channels.
            pretrained ([type], bool): If True, use pretrained model. Defaults to MODEL_PARAMS.pretrained.
        """
        super().__init__()

        self.in_channels = in_channels
        self.pretrained = pretrained

        self.backbone = timm.create_model(
            model_name, pretrained=self.pretrained, in_chans=self.in_channels
        )

        # remove the classification head from the backbone
        self.backbone.reset_classifier(num_classes=0, global_pool="avg")

        # get the last layer's number of features in backbone (feature map)
        self.in_features = self.backbone.num_features
        self.out_features = out_features

        self.single_head_fc = torch.nn.Sequential(
            torch.nn.Linear(self.in_features, self.out_features),
        )
        self.architecture: Dict[str, Callable] = {
            "backbone": self.backbone,
            "bottleneck": None,
            "head": self.single_head_fc,
        }

    def extract_features(self, image: torch.FloatTensor) -> torch.FloatTensor:
        """Extract the feature logits from the model.
        This is the output of the backbone (the feature extractor).
        Args:
            image (torch.FloatTensor): The input image.
        Returns:
            feature_logits (torch.FloatTensor): The feature logits.
        """
        
        print("Sending Tensors to the function Extract Features --- ")
        print_gpu_usage()
        print(f"image dtype : {image.dtype}\nimage shape : {image.shape}")
        
        feature_logits = self.architecture["backbone"](image)
        
        print("After Extract Features ---")
        print(f"feature_logits dtype : {feature_logits.dtype}\nfeature_logits shape : {feature_logits.shape}")
        print_gpu_usage()
        
        return feature_logits

    def forward(self, image: torch.FloatTensor) -> torch.FloatTensor:
        """The forward call of the model.
        Args:
            image (torch.FloatTensor): The input image.
        Returns:
            classifier_logits (torch.FloatTensor): The output logits of the classifier head.
        """
        
        print("1st step: Loading Tensors In Forward Function")
        print_gpu_usage()
        
        feature_logits = self.extract_features(image)
        
        print("After Extracting Features and Before calling Head function ---")
        print_gpu_usage()
        classifier_logits = self.architecture["head"](feature_logits)

        return classifier_logits
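As a quick sanity check (my own snippet, not from either notebook), the class can be smoke-tested on random data; with out_features=1 the head should return one logit per image:

import torch

model = PetNet(pretrained=False).eval()  # pretrained=False skips the weight download
with torch.no_grad():
    images = torch.randn(2, 3, 224, 224)  # batch of 2 RGB 224x224 images
    logits = model(images)
print(logits.shape)  # torch.Size([2, 1])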

I don’t think you are seeing a memory leak; rather, you are allocating additional memory.
E.g., in your train_one_epoch you run a pure FP32 forward pass and then another one wrapped in autocast:

        # Iterate over train batches
        for step, data in enumerate(train_bar, start=1):

            # unpack
            inputs = data["X"].to(self.device, non_blocking=True)
            # .view(-1, 1) if BCELoss
            targets = data["y"].to(self.device, non_blocking=True).view(-1, 1)
            logits = self.model(inputs)  # first forward pass (pure FP32, outside autocast)

            batch_size = inputs.shape[0]
           
            if self.params.use_amp:
                self.optimizer.zero_grad()
                with torch.cuda.amp.autocast(enabled=True):
                    logits = self.model(inputs)  # second forward pass (inside autocast)
                    #criterion = RMSELoss()
                    #curr_batch_train_loss = criterion(logits, targets)
                    curr_batch_train_loss = self.train_criterion(
                        targets,
                        logits,
                        CRITERION_PARAMS,
                    )

which would cause the additional memory usage, as two computation graphs are built and kept alive.

Thanks @ptrblck for the response! That was indeed my mistake: I was calling the model twice.
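For anyone landing here later, a minimal sketch of the corrected loop (the scaler name and placement are my assumptions, not from the original code): the forward pass runs exactly once, inside autocast, and the loss is scaled with torch.cuda.amp.GradScaler:

scaler = torch.cuda.amp.GradScaler(enabled=self.params.use_amp)  # created once, outside the epoch loop

for step, data in enumerate(train_bar, start=1):
    inputs = data["X"].to(self.device, non_blocking=True)
    targets = data["y"].to(self.device, non_blocking=True).view(-1, 1)

    self.optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=self.params.use_amp):
        logits = self.model(inputs)  # the only forward pass
        curr_batch_train_loss = self.train_criterion(targets, logits, CRITERION_PARAMS)

    # GradScaler is a no-op when AMP is disabled, so one code path covers both cases
    scaler.scale(curr_batch_train_loss).backward()
    scaler.step(self.optimizer)
    scaler.update()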