CUDA out of memory, due to dataset size?

I have the weirdest problem: my neural network makes CUDA run out of memory, but how many epochs it manages before the error depends on the size of the dataset.
For example, with a batch_size of 4 throughout: a dataset of 80 images fails after 5 epochs, 40 images after 10 epochs, and 20 images after 19 epochs.
I create a dataset and use a DataLoader to avoid loading all images at once, but I must be doing something wrong somewhere: if I run with my full data (25k images), it won't even start.
Can someone help shed light on this phenomenon?

import glob
import random
import time

import torch
import torch.optim as optim
from torchvision import transforms

# WhaleDataset, net and loss_fn are custom and defined elsewhere

def doYaThang(path, epochs, batch_size, load=""):
    transform = transforms.Compose([transforms.ToTensor()])
    imagelist = glob.glob(path + "*.jpg")
    random.shuffle(imagelist)
    validation_split = 0.3
    validation_amount = int(len(imagelist) * validation_split)
    validation_set = imagelist[:validation_amount]
    training_set = imagelist[validation_amount:]
    whales = WhaleDataset(imagelist=training_set, transform=transform)
    vhales = WhaleDataset(imagelist=validation_set, transform=transform)
    train_loader = torch.utils.data.DataLoader(whales, batch_size=batch_size, num_workers=2, shuffle=True)
    val_loader = torch.utils.data.DataLoader(vhales, batch_size=batch_size, num_workers=2, shuffle=True)
    CNN = net()
    trainNet(CNN, batch_size=batch_size, n_epochs=epochs, learning_rate=0.0001, train_loader=train_loader, val_loader=val_loader)

def trainNet(net, batch_size, n_epochs, learning_rate, train_loader, val_loader):
    net = net.float().cuda()
    n_batches = len(train_loader)
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)
    training_start_time = time.time()
    training_losses = []
    validation_losses = []

    for epoch in range(n_epochs):
        total_train_loss = 0
        for i, inputs in enumerate(train_loader):
            optimizer.zero_grad()
            outputs, mu, logvar = net(inputs)
            print(inputs.shape)
            print(outputs.shape)
            BCEKLD, BCE, KLD = loss_fn(outputs, inputs, mu, logvar)
            BCEKLD.backward()
            optimizer.step()
            total_train_loss += BCEKLD

        total_val_loss = 0
        for i, vinputs in enumerate(val_loader):
            val_outputs, vmu, vlogvar = net(vinputs)
            vBCEKLD, vBCE, vKLD = loss_fn(val_outputs, vinputs, vmu, vlogvar)
            total_val_loss += vBCEKLD
        training_losses.append(total_train_loss)
        validation_losses.append(total_val_loss)
        print("Epoch " + str(epoch))
        print("Training loss = {:.2f}".format(total_train_loss / len(train_loader)))
        print("Validation loss = {:.2f}".format(total_val_loss / len(val_loader)))

The network in question is a VAE with only one conv+maxpool layer and one unpool+convtranspose layer, as I'm debugging the out-of-memory error.

These lines of code append the losses while they are still attached to the computation graph and will thus increase the memory usage in each iteration:

training_losses.append(total_train_loss) 
validation_losses.append(total_val_loss)
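
As a quick sanity check (not in your code), you could print the autograd attributes of the accumulated loss after an epoch; if they are set, the whole graph for that epoch is still being kept alive:

    print(total_train_loss.requires_grad)  # True: the tensor still requires grad
    print(total_train_loss.grad_fn)        # references the autograd graph built over the epoch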

To store the loss value for printing or debugging purposes, you should .detach() the tensor and usually you would also call .item() on it to get a Python float instead of a tensor:

training_losses.append(total_train_loss.detach().cpu().item())
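
Applied to your training loop, the accumulation could then look roughly like this (just a sketch reusing the names from your snippet; total_train_loss becomes a plain Python float):

    total_train_loss = 0.0
    for i, inputs in enumerate(train_loader):
        optimizer.zero_grad()
        outputs, mu, logvar = net(inputs)
        BCEKLD, BCE, KLD = loss_fn(outputs, inputs, mu, logvar)
        BCEKLD.backward()
        optimizer.step()
        total_train_loss += BCEKLD.item()  # .item() returns a float detached from the graph
    training_losses.append(total_train_loss)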

You should also wrap the validation loop in a with torch.no_grad() block, which avoids storing the intermediate tensors that would otherwise be needed for backpropagation and thus reduces the memory usage.
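
For your validation loop that could look like this (again only a sketch with the names from your code):

    total_val_loss = 0.0
    with torch.no_grad():  # no autograd graph is built, so intermediates are freed immediately
        for i, vinputs in enumerate(val_loader):
            val_outputs, vmu, vlogvar = net(vinputs)
            vBCEKLD, vBCE, vKLD = loss_fn(val_outputs, vinputs, vmu, vlogvar)
            total_val_loss += vBCEKLD.item()
    validation_losses.append(total_val_loss)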


Thanks a lot! It worked like a charm.
