I am running a multi-label classification task with 14 classes.
When I use a ResNet-34 with a batch size of 128 and 128x128 input images, everything works on an 8 GB GPU. The output size is [batch_size, 14], and I am using the F.binary_cross_entropy_with_logits loss for this multi-label classification task.
When I change the architecture to a multi-head network, I get a runtime error: CUDA out of memory. The multi-head architecture uses resnet34 as the base and MLPs as heads. I even reduced the MLP layers to size 1 and lowered the batch_size from 128 to 16, but I still get the out-of-memory error.
Here’s the code:
import torch
import torch.nn as nn
from .resnet import *
class Head(nn.Module):
    """Tiny per-class MLP head: ``Linear -> ReLU -> Linear``.

    The hidden layer is a single unit wide, so each head maps
    ``head_elements`` input features to one hidden activation and then to
    ``num_classes`` outputs.

    Args:
        head_elements: Number of input features fed to this head.
        num_classes: Number of outputs produced by this head (default 1).
    """

    def __init__(self, head_elements, num_classes=1):
        super().__init__()
        # Width of the bottleneck between the two linear layers.
        hidden_width = 1
        self.fc1 = nn.Linear(in_features=head_elements, out_features=hidden_width)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(in_features=hidden_width, out_features=num_classes)

    def forward(self, x):
        """Apply the head to ``x`` (last dimension of size ``head_elements``)."""
        # The in-place ReLU overwrites fc1's freshly created output, not ``x``.
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
class Multihead_Resnet(nn.Module):
    """ResNet backbone followed by one independent ``Head`` per class.

    The backbone emits ``num_classes * head_elements`` features. After a
    ReLU, the feature vector is cut into ``num_classes`` consecutive chunks
    of ``head_elements`` columns; chunk ``i`` is fed to ``heads[i]``, which
    produces the single logit for class ``i``.

    Args:
        num_classes: Number of labels, i.e. number of heads and output columns.
        device: Stored on the module as ``self.device`` for callers that read
            it. ``forward`` does not use it: the output lives on whatever
            device the input (and therefore the backbone output) lives on,
            which keeps the module usable under ``nn.DataParallel``.
        head_elements: Number of backbone features routed to each head.
        base_model: Backbone name; only ``'resnet34'`` is supported.

    Raises:
        ValueError: If ``base_model`` is not a supported backbone name.
    """

    def __init__(self, num_classes, device, head_elements=1, base_model='resnet34'):
        super().__init__()
        # Fail at construction time instead of leaving ``self.base_model``
        # undefined and crashing later inside ``forward``.
        if base_model != 'resnet34':
            raise ValueError(
                f"unsupported base_model {base_model!r}; expected 'resnet34'"
            )
        self.device = device
        self.num_classes = num_classes
        self.head_elements = head_elements
        self.base_model_out_size = num_classes * head_elements
        # ``resnet34`` comes from the project's ``.resnet`` module; it is
        # assumed to accept ``num_classes`` as the size of its final layer.
        self.base_model = resnet34(num_classes=self.base_model_out_size)
        self.relu = nn.ReLU(inplace=True)
        self.heads = nn.ModuleList(
            [Head(head_elements=head_elements) for _ in range(num_classes)]
        )

    def forward(self, x):
        """Return per-class logits of shape ``(batch, num_classes)``."""
        features = self.relu(self.base_model(x))
        # One chunk of ``head_elements`` columns per head, in class order.
        chunks = features.split(self.head_elements, dim=1)
        # Each head yields a ``(batch, 1)`` column. Concatenating them avoids
        # writing in place into a pre-allocated leaf tensor that requires
        # grad (which autograd rejects), avoids pinning the output to a fixed
        # device, and keeps the batch dimension intact when batch size is 1.
        logits = [head(chunk) for head, chunk in zip(self.heads, chunks)]
        return torch.cat(logits, dim=1)
I printed out the model and a summary of it:
loading dataset...
len(train_loader.dataset): 86524
train_dataset_array.shape: torch.Size([16, 1, 128, 128])
building model...
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
├─ResNet: 1-1 [-1, 14] --
| └─Conv2d: 2-1 [-1, 64, 64, 64] 3,136
| └─BatchNorm2d: 2-2 [-1, 64, 64, 64] 128
| └─ReLU: 2-3 [-1, 64, 64, 64] --
| └─MaxPool2d: 2-4 [-1, 64, 32, 32] --
| └─Sequential: 2-5 [-1, 64, 32, 32] --
| | └─BasicBlock: 3-1 [-1, 64, 32, 32] 73,984
| | └─BasicBlock: 3-2 [-1, 64, 32, 32] 73,984
| | └─BasicBlock: 3-3 [-1, 64, 32, 32] 73,984
| └─Sequential: 2-6 [-1, 128, 16, 16] --
| | └─BasicBlock: 3-4 [-1, 128, 16, 16] 230,144
| | └─BasicBlock: 3-5 [-1, 128, 16, 16] 295,424
| | └─BasicBlock: 3-6 [-1, 128, 16, 16] 295,424
| | └─BasicBlock: 3-7 [-1, 128, 16, 16] 295,424
| └─Sequential: 2-7 [-1, 256, 8, 8] --
| | └─BasicBlock: 3-8 [-1, 256, 8, 8] 919,040
| | └─BasicBlock: 3-9 [-1, 256, 8, 8] 1,180,672
| | └─BasicBlock: 3-10 [-1, 256, 8, 8] 1,180,672
| | └─BasicBlock: 3-11 [-1, 256, 8, 8] 1,180,672
| | └─BasicBlock: 3-12 [-1, 256, 8, 8] 1,180,672
| | └─BasicBlock: 3-13 [-1, 256, 8, 8] 1,180,672
| └─Sequential: 2-8 [-1, 512, 4, 4] --
| | └─BasicBlock: 3-14 [-1, 512, 4, 4] 3,673,088
| | └─BasicBlock: 3-15 [-1, 512, 4, 4] 4,720,640
| | └─BasicBlock: 3-16 [-1, 512, 4, 4] 4,720,640
| └─AdaptiveAvgPool2d: 2-9 [-1, 512, 1, 1] --
| └─Linear: 2-10 [-1, 14] 7,182
├─ReLU: 1-2 [-1, 14] --
├─ModuleList: 1 [] --
| └─Head: 2-11 [-1, 1] --
| | └─Linear: 3-17 [-1, 1] 2
| | └─ReLU: 3-18 [-1, 1] --
| | └─Linear: 3-19 [-1, 1] 2
| └─Head: 2-12 [-1, 1] --
| | └─Linear: 3-20 [-1, 1] 2
| | └─ReLU: 3-21 [-1, 1] --
| | └─Linear: 3-22 [-1, 1] 2
| └─Head: 2-13 [-1, 1] --
| | └─Linear: 3-23 [-1, 1] 2
| | └─ReLU: 3-24 [-1, 1] --
| | └─Linear: 3-25 [-1, 1] 2
| └─Head: 2-14 [-1, 1] --
| | └─Linear: 3-26 [-1, 1] 2
| | └─ReLU: 3-27 [-1, 1] --
| | └─Linear: 3-28 [-1, 1] 2
| └─Head: 2-15 [-1, 1] --
| | └─Linear: 3-29 [-1, 1] 2
| | └─ReLU: 3-30 [-1, 1] --
| | └─Linear: 3-31 [-1, 1] 2
| └─Head: 2-16 [-1, 1] --
| | └─Linear: 3-32 [-1, 1] 2
| | └─ReLU: 3-33 [-1, 1] --
| | └─Linear: 3-34 [-1, 1] 2
| └─Head: 2-17 [-1, 1] --
| | └─Linear: 3-35 [-1, 1] 2
| | └─ReLU: 3-36 [-1, 1] --
| | └─Linear: 3-37 [-1, 1] 2
| └─Head: 2-18 [-1, 1] --
| | └─Linear: 3-38 [-1, 1] 2
| | └─ReLU: 3-39 [-1, 1] --
| | └─Linear: 3-40 [-1, 1] 2
| └─Head: 2-19 [-1, 1] --
| | └─Linear: 3-41 [-1, 1] 2
| | └─ReLU: 3-42 [-1, 1] --
| | └─Linear: 3-43 [-1, 1] 2
| └─Head: 2-20 [-1, 1] --
| | └─Linear: 3-44 [-1, 1] 2
| | └─ReLU: 3-45 [-1, 1] --
| | └─Linear: 3-46 [-1, 1] 2
| └─Head: 2-21 [-1, 1] --
| | └─Linear: 3-47 [-1, 1] 2
| | └─ReLU: 3-48 [-1, 1] --
| | └─Linear: 3-49 [-1, 1] 2
| └─Head: 2-22 [-1, 1] --
| | └─Linear: 3-50 [-1, 1] 2
| | └─ReLU: 3-51 [-1, 1] --
| | └─Linear: 3-52 [-1, 1] 2
| └─Head: 2-23 [-1, 1] --
| | └─Linear: 3-53 [-1, 1] 2
| | └─ReLU: 3-54 [-1, 1] --
| | └─Linear: 3-55 [-1, 1] 2
| └─Head: 2-24 [-1, 1] --
| | └─Linear: 3-56 [-1, 1] 2
| | └─ReLU: 3-57 [-1, 1] --
| | └─Linear: 3-58 [-1, 1] 2
==========================================================================================
Total params: 21,285,638
Trainable params: 21,285,638
Non-trainable params: 0
Total mult-adds (G): 1.23
==========================================================================================
Input size (MB): 0.06
Forward/backward pass size (MB): 17.75
Params size (MB): 81.20
Estimated Total Size (MB): 99.01
==========================================================================================
<bound method Module.parameters of DataParallel(
(module): Multihead_Resnet(
(base_model): ResNet(
(conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer2): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(3): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer3): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(3): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(4): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(5): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=512, out_features=14, bias=True)
)
(relu): ReLU(inplace=True)
(heads): ModuleList(
(0): Head(
(fc1): Linear(in_features=1, out_features=1, bias=True)
(relu): ReLU(inplace=True)
(fc2): Linear(in_features=1, out_features=1, bias=True)
)
(1): Head(
(fc1): Linear(in_features=1, out_features=1, bias=True)
(relu): ReLU(inplace=True)
(fc2): Linear(in_features=1, out_features=1, bias=True)
)
(2): Head(
(fc1): Linear(in_features=1, out_features=1, bias=True)
(relu): ReLU(inplace=True)
(fc2): Linear(in_features=1, out_features=1, bias=True)
)
(3): Head(
(fc1): Linear(in_features=1, out_features=1, bias=True)
(relu): ReLU(inplace=True)
(fc2): Linear(in_features=1, out_features=1, bias=True)
)
(4): Head(
(fc1): Linear(in_features=1, out_features=1, bias=True)
(relu): ReLU(inplace=True)
(fc2): Linear(in_features=1, out_features=1, bias=True)
)
(5): Head(
(fc1): Linear(in_features=1, out_features=1, bias=True)
(relu): ReLU(inplace=True)
(fc2): Linear(in_features=1, out_features=1, bias=True)
)
(6): Head(
(fc1): Linear(in_features=1, out_features=1, bias=True)
(relu): ReLU(inplace=True)
(fc2): Linear(in_features=1, out_features=1, bias=True)
)
(7): Head(
(fc1): Linear(in_features=1, out_features=1, bias=True)
(relu): ReLU(inplace=True)
(fc2): Linear(in_features=1, out_features=1, bias=True)
)
(8): Head(
(fc1): Linear(in_features=1, out_features=1, bias=True)
(relu): ReLU(inplace=True)
(fc2): Linear(in_features=1, out_features=1, bias=True)
)
(9): Head(
(fc1): Linear(in_features=1, out_features=1, bias=True)
(relu): ReLU(inplace=True)
(fc2): Linear(in_features=1, out_features=1, bias=True)
)
(10): Head(
(fc1): Linear(in_features=1, out_features=1, bias=True)
(relu): ReLU(inplace=True)
(fc2): Linear(in_features=1, out_features=1, bias=True)
)
(11): Head(
(fc1): Linear(in_features=1, out_features=1, bias=True)
(relu): ReLU(inplace=True)
(fc2): Linear(in_features=1, out_features=1, bias=True)
)
(12): Head(
(fc1): Linear(in_features=1, out_features=1, bias=True)
(relu): ReLU(inplace=True)
(fc2): Linear(in_features=1, out_features=1, bias=True)
)
(13): Head(
(fc1): Linear(in_features=1, out_features=1, bias=True)
(relu): ReLU(inplace=True)
(fc2): Linear(in_features=1, out_features=1, bias=True)
)
)
)
)>
The number of parameters is not large, so it is unclear why I get this error. Any help would be appreciated.