Multi-head Architecture faces CUDA out of memory

I am running a multi-label classification task with 14 classes.
When I use a resnet34 with 128 batch size and 128x128 input image everything is ok on a 8GB gpu. The output size is [batch_size, 14] and I am using F.binary_cross_entropy_with_logits loss for this multilabel classification task.

When I change the architecture to be a multi-head network I face runtime error: CUDA out of memory. The multihead architecture is using resnet34 as the base and MLPs as heads. I even changed the MLP layers to be of size 1 and I changed the batch_size from 128 to 16 but still getting memory error.

Here’s the code:

import torch
import torch.nn as nn
from .resnet import *


class Head(nn.Module):
    def __init__(self, head_elements, num_classes=1):
        super().__init__()
        self.fc1 = nn.Linear(head_elements, 1)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(1, num_classes)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


class Multihead_Resnet(nn.Module):
    def __init__(self, num_classes, device, head_elements=1, base_model='resnet34'):
        super().__init__()
        self.device = device
        self.num_classes = num_classes
        self.head_elements = head_elements
        self.base_model_out_size = num_classes * head_elements
        if base_model == 'resnet34':
            self.base_model = resnet34(num_classes=self.base_model_out_size)

        self.relu = nn.ReLU(inplace=True)
        self.heads = nn.ModuleList([Head(head_elements=head_elements) for i in range(num_classes)])
        



    def forward(self, x):
        x = self.base_model(x)
        x = self.relu(x)
        y = torch.zeros(x.shape[0], self.num_classes, requires_grad=True, device=self.device)
        for i in range(self.num_classes):
            y[:, i] = self.heads[i](x[:, i*self.head_elements:i*self.head_elements+self.head_elements]).squeeze()
        
        return y

I printed out the model and a summary of it:

loading dataset...
len(train_loader.dataset): 86524
train_dataset_array.shape: torch.Size([16, 1, 128, 128])
building model...
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─ResNet: 1-1                            [-1, 14]                  --
|    └─Conv2d: 2-1                       [-1, 64, 64, 64]          3,136
|    └─BatchNorm2d: 2-2                  [-1, 64, 64, 64]          128
|    └─ReLU: 2-3                         [-1, 64, 64, 64]          --
|    └─MaxPool2d: 2-4                    [-1, 64, 32, 32]          --
|    └─Sequential: 2-5                   [-1, 64, 32, 32]          --
|    |    └─BasicBlock: 3-1              [-1, 64, 32, 32]          73,984
|    |    └─BasicBlock: 3-2              [-1, 64, 32, 32]          73,984
|    |    └─BasicBlock: 3-3              [-1, 64, 32, 32]          73,984
|    └─Sequential: 2-6                   [-1, 128, 16, 16]         --
|    |    └─BasicBlock: 3-4              [-1, 128, 16, 16]         230,144
|    |    └─BasicBlock: 3-5              [-1, 128, 16, 16]         295,424
|    |    └─BasicBlock: 3-6              [-1, 128, 16, 16]         295,424
|    |    └─BasicBlock: 3-7              [-1, 128, 16, 16]         295,424
|    └─Sequential: 2-7                   [-1, 256, 8, 8]           --
|    |    └─BasicBlock: 3-8              [-1, 256, 8, 8]           919,040
|    |    └─BasicBlock: 3-9              [-1, 256, 8, 8]           1,180,672
|    |    └─BasicBlock: 3-10             [-1, 256, 8, 8]           1,180,672
|    |    └─BasicBlock: 3-11             [-1, 256, 8, 8]           1,180,672
|    |    └─BasicBlock: 3-12             [-1, 256, 8, 8]           1,180,672
|    |    └─BasicBlock: 3-13             [-1, 256, 8, 8]           1,180,672
|    └─Sequential: 2-8                   [-1, 512, 4, 4]           --
|    |    └─BasicBlock: 3-14             [-1, 512, 4, 4]           3,673,088
|    |    └─BasicBlock: 3-15             [-1, 512, 4, 4]           4,720,640
|    |    └─BasicBlock: 3-16             [-1, 512, 4, 4]           4,720,640
|    └─AdaptiveAvgPool2d: 2-9            [-1, 512, 1, 1]           --
|    └─Linear: 2-10                      [-1, 14]                  7,182
├─ReLU: 1-2                              [-1, 14]                  --
├─ModuleList: 1                          []                        --
|    └─Head: 2-11                        [-1, 1]                   --
|    |    └─Linear: 3-17                 [-1, 1]                   2
|    |    └─ReLU: 3-18                   [-1, 1]                   --
|    |    └─Linear: 3-19                 [-1, 1]                   2
|    └─Head: 2-12                        [-1, 1]                   --
|    |    └─Linear: 3-20                 [-1, 1]                   2
|    |    └─ReLU: 3-21                   [-1, 1]                   --
|    |    └─Linear: 3-22                 [-1, 1]                   2
|    └─Head: 2-13                        [-1, 1]                   --
|    |    └─Linear: 3-23                 [-1, 1]                   2
|    |    └─ReLU: 3-24                   [-1, 1]                   --
|    |    └─Linear: 3-25                 [-1, 1]                   2
|    └─Head: 2-14                        [-1, 1]                   --
|    |    └─Linear: 3-26                 [-1, 1]                   2
|    |    └─ReLU: 3-27                   [-1, 1]                   --
|    |    └─Linear: 3-28                 [-1, 1]                   2
|    └─Head: 2-15                        [-1, 1]                   --
|    |    └─Linear: 3-29                 [-1, 1]                   2
|    |    └─ReLU: 3-30                   [-1, 1]                   --
|    |    └─Linear: 3-31                 [-1, 1]                   2
|    └─Head: 2-16                        [-1, 1]                   --
|    |    └─Linear: 3-32                 [-1, 1]                   2
|    |    └─ReLU: 3-33                   [-1, 1]                   --
|    |    └─Linear: 3-34                 [-1, 1]                   2
|    └─Head: 2-17                        [-1, 1]                   --
|    |    └─Linear: 3-35                 [-1, 1]                   2
|    |    └─ReLU: 3-36                   [-1, 1]                   --
|    |    └─Linear: 3-37                 [-1, 1]                   2
|    └─Head: 2-18                        [-1, 1]                   --
|    |    └─Linear: 3-38                 [-1, 1]                   2
|    |    └─ReLU: 3-39                   [-1, 1]                   --
|    |    └─Linear: 3-40                 [-1, 1]                   2
|    └─Head: 2-19                        [-1, 1]                   --
|    |    └─Linear: 3-41                 [-1, 1]                   2
|    |    └─ReLU: 3-42                   [-1, 1]                   --
|    |    └─Linear: 3-43                 [-1, 1]                   2
|    └─Head: 2-20                        [-1, 1]                   --
|    |    └─Linear: 3-44                 [-1, 1]                   2
|    |    └─ReLU: 3-45                   [-1, 1]                   --
|    |    └─Linear: 3-46                 [-1, 1]                   2
|    └─Head: 2-21                        [-1, 1]                   --
|    |    └─Linear: 3-47                 [-1, 1]                   2
|    |    └─ReLU: 3-48                   [-1, 1]                   --
|    |    └─Linear: 3-49                 [-1, 1]                   2
|    └─Head: 2-22                        [-1, 1]                   --
|    |    └─Linear: 3-50                 [-1, 1]                   2
|    |    └─ReLU: 3-51                   [-1, 1]                   --
|    |    └─Linear: 3-52                 [-1, 1]                   2
|    └─Head: 2-23                        [-1, 1]                   --
|    |    └─Linear: 3-53                 [-1, 1]                   2
|    |    └─ReLU: 3-54                   [-1, 1]                   --
|    |    └─Linear: 3-55                 [-1, 1]                   2
|    └─Head: 2-24                        [-1, 1]                   --
|    |    └─Linear: 3-56                 [-1, 1]                   2
|    |    └─ReLU: 3-57                   [-1, 1]                   --
|    |    └─Linear: 3-58                 [-1, 1]                   2
==========================================================================================
Total params: 21,285,638
Trainable params: 21,285,638
Non-trainable params: 0
Total mult-adds (G): 1.23
==========================================================================================
Input size (MB): 0.06
Forward/backward pass size (MB): 17.75
Params size (MB): 81.20
Estimated Total Size (MB): 99.01
==========================================================================================
<bound method Module.parameters of DataParallel(
  (module): Multihead_Resnet(
    (base_model): ResNet(
      (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (2): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer2): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (2): BasicBlock(
          (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (3): BasicBlock(
          (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer3): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (2): BasicBlock(
          (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (3): BasicBlock(
          (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (4): BasicBlock(
          (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (5): BasicBlock(
          (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (2): BasicBlock(
          (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
      (fc): Linear(in_features=512, out_features=14, bias=True)
    )
    (relu): ReLU(inplace=True)
    (heads): ModuleList(
      (0): Head(
        (fc1): Linear(in_features=1, out_features=1, bias=True)
        (relu): ReLU(inplace=True)
        (fc2): Linear(in_features=1, out_features=1, bias=True)
      )
      (1): Head(
        (fc1): Linear(in_features=1, out_features=1, bias=True)
        (relu): ReLU(inplace=True)
        (fc2): Linear(in_features=1, out_features=1, bias=True)
      )
      (2): Head(
        (fc1): Linear(in_features=1, out_features=1, bias=True)
        (relu): ReLU(inplace=True)
        (fc2): Linear(in_features=1, out_features=1, bias=True)
      )
      (3): Head(
        (fc1): Linear(in_features=1, out_features=1, bias=True)
        (relu): ReLU(inplace=True)
        (fc2): Linear(in_features=1, out_features=1, bias=True)
      )
      (4): Head(
        (fc1): Linear(in_features=1, out_features=1, bias=True)
        (relu): ReLU(inplace=True)
        (fc2): Linear(in_features=1, out_features=1, bias=True)
      )
      (5): Head(
        (fc1): Linear(in_features=1, out_features=1, bias=True)
        (relu): ReLU(inplace=True)
        (fc2): Linear(in_features=1, out_features=1, bias=True)
      )
      (6): Head(
        (fc1): Linear(in_features=1, out_features=1, bias=True)
        (relu): ReLU(inplace=True)
        (fc2): Linear(in_features=1, out_features=1, bias=True)
      )
      (7): Head(
        (fc1): Linear(in_features=1, out_features=1, bias=True)
        (relu): ReLU(inplace=True)
        (fc2): Linear(in_features=1, out_features=1, bias=True)
      )
      (8): Head(
        (fc1): Linear(in_features=1, out_features=1, bias=True)
        (relu): ReLU(inplace=True)
        (fc2): Linear(in_features=1, out_features=1, bias=True)
      )
      (9): Head(
        (fc1): Linear(in_features=1, out_features=1, bias=True)
        (relu): ReLU(inplace=True)
        (fc2): Linear(in_features=1, out_features=1, bias=True)
      )
      (10): Head(
        (fc1): Linear(in_features=1, out_features=1, bias=True)
        (relu): ReLU(inplace=True)
        (fc2): Linear(in_features=1, out_features=1, bias=True)
      )
      (11): Head(
        (fc1): Linear(in_features=1, out_features=1, bias=True)
        (relu): ReLU(inplace=True)
        (fc2): Linear(in_features=1, out_features=1, bias=True)
      )
      (12): Head(
        (fc1): Linear(in_features=1, out_features=1, bias=True)
        (relu): ReLU(inplace=True)
        (fc2): Linear(in_features=1, out_features=1, bias=True)
      )
      (13): Head(
        (fc1): Linear(in_features=1, out_features=1, bias=True)
        (relu): ReLU(inplace=True)
        (fc2): Linear(in_features=1, out_features=1, bias=True)
      )
    )
  )
)>

The number of parameters is not much. It is unclear why I get this error. Any help would be appreciated.

The problem was solved by changing the forward function to:

def forward(self, x):
        x = self.base_model(x)
        x = self.relu(x)
        a = []
        for i in range(self.num_classes):
            a.append(self.heads[i](x[:, i*self.head_elements:i*self.head_elements+self.head_elements]).squeeze())
        
        return torch.stack(a, dim=1)

but I still don’t know why the former approach resulted in OOM error. The only difference that is made is that I store the output tensors in a python list instead of a new tensor. Is declaring a tensor in forward function or indexing it causing the problem?