CUDA out of memory partway through training. Not sure whether it's due to batch norm.

File "", line 226, in 
File "", line 94, in main
train_phase(model, optimizer, scheduler, epoch, data_loaders, data_size)
File "", line 138, in train_phase
anc_embed, pos_embed, neg_embed = model(anc_img), model(pos_img), model(neg_img)
File "/home/kelvinheng/anaconda3/envs/proj1/lib/python3.7/site-packages/torch/nn/modules/", line 489, in  **call**
result = self.forward(*input, **kwargs)
File "/home/kelvinheng/proj1_code/", line 37, in forward
embed = self.resnet50(images)
File "/home/kelvinheng/anaconda3/envs/proj1/lib/python3.7/site-packages/torch/nn/modules/", line 489, in  **call**
result = self.forward(*input, **kwargs)
File "/home/kelvinheng/anaconda3/envs/proj1/lib/python3.7/site-packages/torchvision/models/", line 157, in forward
x = self.layer3(x)
File "/home/kelvinheng/anaconda3/envs/proj1/lib/python3.7/site-packages/torch/nn/modules/", line 489, in  **call**
result = self.forward(*input, **kwargs)
File "/home/kelvinheng/anaconda3/envs/proj1/lib/python3.7/site-packages/torch/nn/modules/", line 92, in forward
input = module(input)
File "/home/kelvinheng/anaconda3/envs/proj1/lib/python3.7/site-packages/torch/nn/modules/", line 489, in  **call**
result = self.forward(*input, **kwargs)
File "/home/kelvinheng/anaconda3/envs/proj1/lib/python3.7/site-packages/torchvision/models/", line 84, in forward
out = self.bn2(out)
File "/home/kelvinheng/anaconda3/envs/proj1/lib/python3.7/site-packages/torch/nn/modules/", line 489, in  **call**
result = self.forward(*input, **kwargs)
File "/home/kelvinheng/anaconda3/envs/proj1/lib/python3.7/site-packages/torch/nn/modules/", line 76, in forward
exponential_average_factor, self.eps)
File "/home/kelvinheng/anaconda3/envs/proj1/lib/python3.7/site-packages/torch/nn/", line 1623, in batch_norm
training, momentum, eps, torch.backends.cudnn.enabled
RuntimeError: CUDA out of memory. Tried to allocate 5.75 MiB (GPU 0; 10.73 GiB total capacity; 9.64 GiB already allocated; 3.25 MiB free; 121.92 MiB cached)

The above is my error. The training runs for 60 epochs before CUDA runs out of memory. I am not sure whether it is due to batch norm. If I decrease my batch size, I can run for a few more epochs before CUDA runs out of memory. Can anyone help?

I am using an RTX 2080 Ti with PyTorch 1.0, Python 3.7, and CUDA 10.0. The model is just a basic resnet50 from torchvision.models; I changed the last fc layer to output 256-dimensional embeddings, and I train it with triplet loss.


You might have a memory leak if your code runs fine for a few epochs and then runs out of memory.
Could you run it again and have a look at nvidia-smi?
Do you see an increasing memory usage?

If so, would it be possible to post a (small) executable script so that we could have a look at your code?
Make sure to avoid storing the computation graph, e.g. by doing something like:

loss = criterion(output, target)
# instead use
losses.append(loss.item())  # detaches the computation graph

PS: I've formatted your post for better readability. If you would like to add code or terminal outputs, you can wrap them in three backticks ``` :wink:

Hi, Thank you very much for the reply.

I ran it a few times and did not observe a memory increase. It has been stable at around 9GB out of 11GB memory. The only thing is that it just suddenly runs out of memory halfway.

Please see below for the code snippet

This is the train script.

"""define some global variables"""
# Module-level configuration and training objects shared by main(),
# train_phase() and valid_phase().
args = parser.parse_args()
# Timestamped name of the output folder for this run, e.g. '20190131-142530'.
# NOTE(review): the posted line had lost the first argument of strftime;
# datetime.now() is the reconstruction.
args.train_folder = datetime.now().strftime('%Y%m%d-%H%M%S')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = Resnet50(embedding_dim=args.emb_dim, pretrained=False).to(device)
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
scheduler = lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)

# The loss module is created once here rather than inside the training loop
# (author's note: re-creating it per iteration was suspected of causing
# out-of-memory errors).
# p=2 selects Euclidean distance; reduction='mean' averages the triplet loss
# over the batch.
triplet_loss_fn = torch.nn.TripletMarginLoss(margin=args.margin, p=2, reduction='mean')

def main():
    """Run training and validation for ``args.num_epochs`` epochs.

    Uses the module-level ``args``, ``model``, ``optimizer``, ``scheduler``
    and ``device``. When ``args.model_weight_path`` is set, the model and
    optimizer state and the epoch counter are restored from that checkpoint
    before training continues; otherwise training starts at epoch 0.
    """
    if args.model_weight_path:
        # Resume: restore everything train_phase() stores in its checkpoints.
        checkpoint = torch.load(args.model_weight_path, map_location=device)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # NOTE(review): the checkpoint is written after its epoch finished, so
        # `checkpoint['epoch'] + 1` may be the intended resume point - confirm.
        args.start_epoch = checkpoint['epoch']
    else:
        args.start_epoch = 0

    end_epoch = args.start_epoch + args.num_epochs
    for epoch in tqdm(range(args.start_epoch, end_epoch), desc='epochs'):

        print(80 * '=')
        print('Epoch [{}/{}]'.format(epoch, end_epoch - 1))

        # The data loaders are rebuilt every epoch, as in the original.
        data_loaders, data_size = get_dataloader(args.train_root_dir, args.valid_root_dir,
                                                 args.train_csv_name, args.valid_csv_name,
                                                 args.batch_size, args.num_workers,
                                                 args.image_size, args.cropped_image_size)
        print('data loaded for epoch {}'.format(epoch))

        train_phase(model, optimizer, scheduler, epoch, data_loaders, data_size)
        valid_phase(model, epoch, data_loaders, data_size)

def train_phase(model, optimizer, scheduler, epoch, dataloaders, data_size):
    """Train ``model`` for one epoch, log the average loss and save a checkpoint.

    Args:
        model: network mapping an image batch to an embedding batch.
        optimizer: optimizer updating ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        epoch: index of the current epoch (used for logging and file names).
        dataloaders: mapping with a ``'train'`` loader whose batches are dicts
            holding ``'anc_img'``, ``'pos_img'`` and ``'neg_img'`` tensors.
        data_size: mapping with the number of training samples under
            ``'train'``.

    Side effects:
        Creates ``args.train_folder`` if needed (writing ``args.txt`` once),
        appends to ``train_loss.txt`` and writes
        ``checkpoint_epoch_<epoch>.pth`` inside that folder.

    NOTE(review): the posted snippet had lost several statements (the body of
    the zero-loss ``if``, the optimizer update, the folder creation and the
    ``torch.save`` call). They are reconstructed here; confirm against the
    author's real script.
    """
    phase = 'train'
    print('in phase train for epoch {}'.format(epoch))

    model.train()
    triplet_loss_sum = 0.0

    for batch_idx, batch_sample in tqdm(enumerate(dataloaders['train']), desc='batches'):

        batch_size = batch_sample['anc_img'].shape[0]  # current batch size

        anc_img = batch_sample['anc_img'].to(device)
        pos_img = batch_sample['pos_img'].to(device)
        neg_img = batch_sample['neg_img'].to(device)

        with torch.set_grad_enabled(True):

            # Embeddings of the anchor, positive and negative images.
            anc_embed, pos_embed, neg_embed = model(anc_img), model(pos_img), model(neg_img)

            # Triplet loss averaged over the batch.
            batch_avg_triplet_loss = triplet_loss_fn(anc_embed, pos_embed, neg_embed)
            # .item() yields a Python float, so the running sum below never
            # keeps the autograd graph of a batch alive.
            batch_loss = batch_avg_triplet_loss.item()

            # NOTE(review): the original tested `== 0` but its body was lost.
            # A zero loss carries no gradient, so the update is skipped.
            if batch_loss != 0:
                optimizer.zero_grad()
                batch_avg_triplet_loss.backward()
                optimizer.step()

        triplet_loss_sum += batch_loss * batch_size
        print('loss for epoch {}, batch {}: {}'.format(epoch, batch_idx, batch_loss))

    # Advance the learning-rate schedule once per epoch, after the optimizer
    # updates of that epoch.
    scheduler.step()

    epoch_avg_triplet_loss = triplet_loss_sum / data_size[phase]
    print('{} set - Average Triplet Loss for epoch {}: {}'.format(phase, epoch, epoch_avg_triplet_loss))

    # Create the output folder on first use and record the run arguments.
    if not os.path.isdir(args.train_folder):
        os.makedirs(args.train_folder)
        with open('{}/args.txt'.format(args.train_folder), 'a+') as f:
            for key, value in vars(args).items():
                f.write('%s: %s\n' % (key, str(value)))

    with open('{}/train_loss.txt'.format(args.train_folder), 'a+') as f:
        f.write('{}: {}\n'.format(str(epoch), str(round(epoch_avg_triplet_loss, 4))))

    torch.save({'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()},
               '{}/checkpoint_epoch_{}.pth'.format(args.train_folder, epoch))

def valid_phase(model, epoch, dataloaders, data_size):
    """Evaluate nearest-neighbour accuracy of the embeddings on the validation set.

    Every validation image is embedded; an image counts as a match when the
    closest *other* image (by ``pdist``) has the same class label. The
    accuracy is printed and appended to ``accuracy.txt`` in
    ``args.train_folder``.

    Args:
        model: network mapping an image batch to an embedding batch.
        epoch: index of the current epoch (used for logging).
        dataloaders: mapping with a ``'valid'`` loader whose batches are dicts
            holding ``'img'``, ``'img_class'`` and ``'img_path'``.
        data_size: mapping with the number of validation samples under
            ``'valid'``.
    """
    print('in phase valid for epoch {}'.format(epoch))

    full_emb = []
    full_class = []
    full_img_path = []
    valid_size = data_size['valid']
    matches = 0

    # Evaluate in eval mode so batch-norm uses its running statistics, then
    # restore whatever mode the caller had set.
    was_training = model.training
    model.eval()
    try:
        with torch.set_grad_enabled(False):
            for batch_sample in dataloaders['valid']:
                img = batch_sample['img'].to(device)
                img_embed = model(img)

                # Collect results on the CPU so the validation set does not
                # accumulate in GPU memory.
                full_emb.append(img_embed.cpu())
                full_class.append(batch_sample['img_class'].cpu())
                full_img_path += batch_sample['img_path']
    finally:
        model.train(was_training)

    full_emb = torch.cat(full_emb, 0)
    full_class = torch.cat(full_class, 0)

    pairwise_distance = pdist(full_emb, full_emb)
    # Exclude each image from being its own nearest neighbour.
    pairwise_distance[range(valid_size), range(valid_size)] = float('inf')
    closest_img_idx = torch.argmin(pairwise_distance, dim=1)

    for idx in range(valid_size):
        person_class = full_class[idx].item()
        # Look up the class of the nearest neighbour; the posted code compared
        # against the neighbour's index instead.
        closest_img_class = full_class[closest_img_idx[idx]].item()
        print('full_emb: ', full_emb[idx, :10])
        print('full_img_path: ', full_img_path[idx])
        print('own class: ', person_class)
        print('closest img class: ', closest_img_class)
        if int(person_class) == int(closest_img_class):
            matches += 1

    accuracy = float(matches / valid_size)
    print('accuracy for epoch {} = {}/{} = {}'.format(epoch, matches, valid_size, accuracy))

    with open('{}/accuracy.txt'.format(args.train_folder), 'a+') as f:
        f.write('{}: {}\n'.format(str(epoch), str(round(accuracy, 4))))

# Script entry point: run training only when executed directly.
if __name__ == '__main__':
    main()

This is the model

class Resnet50(nn.Module):
    """ResNet-50 whose final fc layer is replaced to emit L2-normalised embeddings.

    Args:
        embedding_dim: size of the output embedding vector.
        pretrained: forwarded to ``models.resnet50``.
    """

    def __init__(self, embedding_dim=512, pretrained=False):
        super(Resnet50, self).__init__()
        self.embedding_dim = embedding_dim
        self.resnet50 = models.resnet50(pretrained=pretrained)
        # Swap the classification head for a linear embedding layer.
        self.linear = nn.Linear(self.resnet50.fc.in_features, embedding_dim)
        self.resnet50.fc = self.linear

    def init_weights(self):
        """Re-initialise the embedding layer's weights from N(0, 0.02).

        NOTE(review): the body was truncated in the posted snippet (only
        ``, 0.02)`` survived); a normal init with std 0.02 is the presumed
        original - confirm.
        """
        nn.init.normal_(self.linear.weight, 0.0, 0.02)

    def l2_norm(self, input):
        """Scale each row of the 2-D tensor ``input`` to unit Euclidean length.

        A small constant (1e-10) is added to the squared norm so that all-zero
        rows do not divide by zero.
        """
        squared_norm = torch.sum(torch.pow(input, 2), 1).add_(1e-10)
        norm = torch.sqrt(squared_norm)
        # Broadcast the per-row norm across the feature dimension.
        return torch.div(input, norm.view(-1, 1).expand_as(input))

    def forward(self, images):
        """Return the L2-normalised embeddings for a batch of ``images``."""
        embed = self.resnet50(images)
        return self.l2_norm(embed)

This is the loss

class TripletLoss(torch.nn.Module):
    """Triplet margin loss on Euclidean distances.

    Computes ``mean(max(0, margin + d(anchor, positive) - d(anchor, negative)))``
    over the batch, where ``d`` is the pairwise L2 distance.

    This is a plain module rather than an ``autograd.Function``: it uses only
    differentiable torch operations, so autograd derives the backward pass,
    and instances can be called directly (``loss_fn(a, p, n)``) as well as
    through ``.forward(...)``.

    Args:
        margin: minimum gap required between the negative and positive
            distances.
    """

    def __init__(self, margin):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.pdist = PairwiseDistance(p=2)

    def forward(self, anchor, positive, negative):
        """Return the batch-averaged triplet loss as a scalar tensor."""
        pos_dist = self.pdist(anchor, positive)
        neg_dist = self.pdist(anchor, negative)

        # Hinge: only triplets that violate the margin contribute.
        hinge_dist = torch.clamp(self.margin + pos_dist - neg_dist, min=0.0)
        return torch.mean(hinge_dist)

def pdist(emb1, emb2, eps=1e-6):
    """Compute the matrix of all pairwise Euclidean distances.

    Uses the expansion ``||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b`` so the
    whole matrix comes from one matrix product.

    Args:
        emb1: tensor of shape ``(n_1, d)``.
        emb2: tensor of shape ``(n_2, d)``.
        eps: small constant added under the square root for numerical
            stability; as a result the distance between identical rows is
            ``sqrt(eps)`` rather than exactly 0.

    Returns:
        Tensor of shape ``(n_1, n_2)`` whose ``[i, j]`` entry is
        ``sqrt(eps + | ||emb1[i]||^2 + ||emb2[j]||^2 - 2 * emb1[i].emb2[j] |)``,
        i.e. approximately ``|| emb1[i, :] - emb2[j, :] ||_2``.
    """
    n_1, n_2 = emb1.size(0), emb2.size(0)
    norms_1 = torch.sum(emb1 ** 2, dim=1, keepdim=True)
    norms_2 = torch.sum(emb2 ** 2, dim=1, keepdim=True)
    norms = (norms_1.expand(n_1, n_2) +
             norms_2.transpose(0, 1).expand(n_1, n_2))
    # Cross term of the expansion: the Gram matrix between the two sets.
    distances_squared = norms - 2 * emb1.mm(emb2.t())
    # abs() guards against tiny negative values caused by rounding.
    return torch.sqrt(eps + torch.abs(distances_squared))

Just to be entirely sure: if I am using an RTX 2080 Ti and would like a compatible PyTorch, CUDA and cuDNN environment, does conda support that, or do I have to install from source? Judging by the discussion forums, there seem to be some compatibility issues between the RTX 2080 Ti and the conda build of PyTorch.

Thanks in advance :slight_smile:

1 Like

I am having exactly the same issue with a 2080 Ti. @Kelvin_Heng, did you solve it? Thanks.