[SOLVED] CUDA out of memory even while using DataParallel and reducing batch size

I am trying to train a resnet18 model on the CUB birds dataset with a batch size of 16 split across 4 GPUs using nn.DataParallel. My ResNet code, adapted from here, is as follows:

'''ResNet in PyTorch.
For Pre-activation ResNet, see 'preact_resnet.py'.
Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=200):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(2048, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        # print(out.shape)
        out = self.linear(out)
        return out


def ResNet18():
    return ResNet(BasicBlock, [2,2,2,2])

def ResNet34():
    return ResNet(BasicBlock, [3,4,6,3])

def ResNet50():
    return ResNet(Bottleneck, [3,4,6,3])

def ResNet101():
    return ResNet(Bottleneck, [3,4,23,3])

def ResNet152():
    return ResNet(Bottleneck, [3,8,36,3])
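
For reference, a quick shape sanity check for this definition (the 64x64 dummy size below is my assumption, inferred from the nn.Linear(2048, num_classes) head; adjust it to your actual crop size):

import torch

if __name__ == '__main__':
    net = ResNet34()
    dummy = torch.randn(2, 3, 64, 64)  # two fake RGB crops; 64/8 = 8 after the three strided layers
    out = net(dummy)                   # avg_pool2d(.., 4) -> 2x2 spatial, 512*2*2 = 2048 features
    print(out.shape)                   # expected: torch.Size([2, 200])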

My main training and test driver is as follows:

criterion = nn.CrossEntropyLoss()

logger = Logger(logs_path)

model = ResNet34()
model = nn.DataParallel(model)
model = model.cuda()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

for epoch in range(1000):
    train_loss, train_acc = train(epoch)
    test_loss, test_acc = test(epoch)
    print("%f Epoch; %f Train Loss, %f Test Loss, %f Train Acc, %f Test Acc" % (epoch, train_loss, test_loss,
                                                                                    train_acc, test_acc))

My train and test loops are as follows:

def train(epoch):
    model.train()
    train_loss = 0.0
    train_acc = 0.0
    count = 0.0
    for id, sample in enumerate(train_loader):
        count += 1
        image, label = sample['image'], sample['label']
        if torch.cuda.is_available():
            image = image.cuda()
            label = label.cuda()

        optimizer.zero_grad()
        predictions = model(image)
        predictions = predictions.view(-1, 200)
        loss = criterion(predictions, label)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, predictions = predictions.detach().max(1)
        correct = predictions.eq(label.detach()).sum().item()
        acc = 100.0 * correct / image.size(0)
        train_acc += acc
    # print(train_loss, count)
    train_loss = train_loss / count
    train_acc = train_acc / count
    info = {'train_loss': train_loss, 'train_acc': train_acc}
    for k, v in info.items():
        logger.scalar_summary(k, v, epoch+1)

    return train_loss, train_acc


def test(epoch):
    model.eval()
    test_loss = 0.0
    test_acc = 0.0
    count = 0.0
    for id, sample in enumerate(test_loader):
        count += 1
        image, label = sample['image'], sample['label']
        if torch.cuda.is_available():
            image = image.cuda()
            label = label.cuda()

        predictions = model(image)
        predictions = predictions.view(-1, 200)
        loss = criterion(predictions, label)
        test_loss += loss
        _, predictions = predictions.detach().max(1)
        correct = predictions.eq(label.detach()).sum().item()
        acc = 100.0 * correct / image.size(0)
        test_acc += acc
    test_loss = test_loss / count
    test_acc = test_acc / count
    info = {'test_loss': test_loss, 'test_acc': test_acc}
    for k, v in info.items():
        logger.scalar_summary(k, v, epoch + 1)

    return test_loss, test_acc

I get the following traceback:

 CUDA_VISIBLE_DEVICES=0,1,2,3 python -m checks.check
I tensorflow/stream_executor/dso_loader.cc:135] successfully opened CUDA library libcublas.so.8.0 locally
I tensorflow/stream_executor/dso_loader.cc:135] successfully opened CUDA library libcudnn.so.5 locally
I tensorflow/stream_executor/dso_loader.cc:135] successfully opened CUDA library libcufft.so.8.0 locally
I tensorflow/stream_executor/dso_loader.cc:135] successfully opened CUDA library libcuda.so.1 locally
I tensorflow/stream_executor/dso_loader.cc:135] successfully opened CUDA library libcurand.so.8.0 locally
Traceback (most recent call last):
  File "/usr/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/.../checks/check.py", line 190, in <module>
    test_loss, test_acc = test(epoch)
  File "/.../checks/check.py", line 138, in test
    predictions = model(image)
  File "/.../lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/.../lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 123, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/.../lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 133, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/.../lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 77, in parallel_apply
    raise output
  File "/.../python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 53, in _worker
    output = module(*input, **kwargs)
  File "/.../lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/.../networks_supervised/resnet.py", line 89, in forward
    out = self.layer2(out)
  File "/.../lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/.../lib/python3.6/site-packages/torch/nn/modules/container.py", line 91, in forward
    input = module(input)
  File "/.../lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/.../networks_supervised/resnet.py", line 33, in forward
    out = F.relu(out)
  File "/.../lib/python3.6/site-packages/torch/nn/functional.py", line 643, in relu
    return torch.relu(input)
RuntimeError: CUDA error: out of memory

Would love to get any advice on what I am doing wrong. Thanks :v:

Solved it by wrapping the for loop in test() in

with torch.no_grad():

so that no autograd graph is built during evaluation.
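
For anyone who hits the same thing, the fixed test loop looks roughly like this (same code as above, just wrapped in no_grad; switching to loss.item() is an extra precaution on my part, since accumulating the raw loss tensor also keeps references to each batch's graph alive):

def test(epoch):
    model.eval()
    test_loss = 0.0
    test_acc = 0.0
    count = 0.0
    with torch.no_grad():  # no autograd graph is built, so activations are freed right after each forward pass
        for id, sample in enumerate(test_loader):
            count += 1
            image, label = sample['image'], sample['label']
            if torch.cuda.is_available():
                image = image.cuda()
                label = label.cuda()

            predictions = model(image)
            predictions = predictions.view(-1, 200)
            loss = criterion(predictions, label)
            test_loss += loss.item()  # .item() keeps a plain float instead of holding on to the loss tensor
            _, predicted = predictions.max(1)
            correct = predicted.eq(label).sum().item()
            test_acc += 100.0 * correct / image.size(0)

    test_loss = test_loss / count
    test_acc = test_acc / count
    info = {'test_loss': test_loss, 'test_acc': test_acc}
    for k, v in info.items():
        logger.scalar_summary(k, v, epoch + 1)

    return test_loss, test_acc

My understanding of why only test() blew up: train() calls backward(), which frees each batch's graph, but the eval loop never does, and test_loss += loss kept every batch's loss tensor (and its graph) alive, so memory grew until the GPUs ran out.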