GPU memory issue on UNET


(Antoine) #1

Hi!

I’m having trouble with a CNN: I implemented an encoder-decoder (UNET) for image segmentation. I made a custom dataset which returns an image together with its associated ground truth (loaded in the __getitem__() method).

The problem is that I cannot load more than 200 images (size: 480x640x3); otherwise my GPU runs out of memory (OOM) :confused:.
Here is a screenshot of what’s happening in my RAM during training (each drop corresponds to the end of one epoch and the beginning of the next):
[Screenshot: RAM usage during training]
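A quick back-of-the-envelope check of the numbers in the question, assuming each image is kept as a float32 tensor (4 bytes per value). tensor_bytes is a hypothetical helper written for this sketch:

```python
# Rough memory footprint of keeping every training image resident at once.
# Assumption: images are stored as float32 tensors (4 bytes per value).
from math import prod

MIB = 1024 ** 2


def tensor_bytes(shape, bytes_per_value=4):
    """Bytes used by a dense tensor of the given shape (hypothetical helper)."""
    return prod(shape) * bytes_per_value


per_image = tensor_bytes((3, 480, 640))   # one 480x640 RGB image
all_images = 200 * per_image              # the 200 preloaded images
print(f"per image:  {per_image / MIB:.2f} MiB")
print(f"200 images: {all_images / MIB:.2f} MiB")
```

That is about 3.5 MiB per image and about 700 MiB for 200 images. Masks, model weights, gradients, optimizer state and, above all, the activations saved for the backward pass all come on top of this.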

How can I increase the number of images to train on?

Thank you very much :smiley:


(Justus Schock) #2

You should not load all images of an epoch into GPU RAM, but only the ones of the current batch (and keep the rest in CPU RAM or load them from disk just in time).
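A minimal sketch of the "load from disk just in time" approach, assuming each sample is stored as a pair of .npy files (a hypothetical layout; LazySegmentationDataset is an illustrative name, not an existing class). The dataset keeps only file paths in memory and returns CPU tensors, so nothing touches the GPU until the training loop moves a batch there:

```python
import numpy as np
import torch
from torch.utils.data import Dataset


class LazySegmentationDataset(Dataset):
    """Keeps only file paths in memory; reads each image/mask pair on demand.

    Hypothetical layout: every sample is stored as two .npy files, an HxWx3
    uint8 image and an HxW integer mask. Swap np.load for PIL/OpenCV if the
    data is stored as PNG/JPEG.
    """

    def __init__(self, image_paths, mask_paths):
        assert len(image_paths) == len(mask_paths)
        self.image_paths = list(image_paths)
        self.mask_paths = list(mask_paths)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Read from disk only when this sample is requested.
        image = np.load(self.image_paths[idx])   # H x W x 3, uint8
        mask = np.load(self.mask_paths[idx])     # H x W, integer class labels
        # HWC -> CHW and scale to [0, 1]; both tensors stay on the CPU here.
        image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
        mask = torch.from_numpy(mask).long()
        return {'image': image, 'mask': mask}
```

With this design, CPU RAM holds only the samples of the batches currently being prepared by the DataLoader, not the whole dataset.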


(Antoine) #3

Thank you for your answer,

I am not sure how the following code behaves:

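def train(model, train_loader, criterion, optimizer, epoch):
    model.train()
    running_loss = 0.0
    for i, data in enumerate(train_loader):
        # Each iteration yields one batch from the DataLoader
        image = data['image']
        mask = data['mask']

        optimizer.zero_grad()
        output = model(image)

        loss = criterion(output, mask)
        loss.backward()

        optimizer.step()
        # .item() returns a Python float, so the autograd graph is not kept alive
        running_loss += loss.item()
    return running_loss / len(train_loader)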

(train is called once per epoch.)

I thought it loaded each image and mask as a pair, and that after the loss was computed the pair was freed from the GPU. Is that right?

By the way, my batch size is 2.


(Justus Schock) #4

Not necessarily; it depends on your train_loader. I’d recommend loading the data into CPU RAM inside your loader and only pushing each batch to the GPU inside the training loop.
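A minimal sketch of that recommendation, assuming the loader yields dicts with 'image' and 'mask' CPU tensors as in the loop above. The device argument is an addition for illustration:

```python
import torch


def train(model, train_loader, criterion, optimizer, device):
    """One epoch; only the current batch is copied to `device`."""
    model.train()
    running_loss = 0.0
    for data in train_loader:
        # The DataLoader yields CPU tensors; copy just this batch to the device.
        # non_blocking=True only helps when the loader uses pin_memory=True.
        image = data['image'].to(device, non_blocking=True)
        mask = data['mask'].to(device, non_blocking=True)

        optimizer.zero_grad()
        output = model(image)
        loss = criterion(output, mask)
        loss.backward()
        optimizer.step()

        # .item() gives a Python float, so no graph or device tensor is kept.
        running_loss += loss.item()
        # image, mask, output and loss are rebound on the next iteration, so
        # the previous batch's GPU memory returns to PyTorch's caching allocator.
    return running_loss / len(train_loader)
```

The loader itself could then be built with something like DataLoader(dataset, batch_size=2, shuffle=True, pin_memory=True), with device = torch.device('cuda') when a GPU is available.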


(Antoine) #5

I moved the code that copies the data to the GPU from my Dataset class into the train function, but nothing changed.
This is how train_loader is created:

Note that when I try to increase the batch size from 1 to 2 or more, I get this error (maybe it helps to find the problem):

The error occurs at the beginning of the first convolutional layer…
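Where the error occurs is consistent with a rough estimate of activation memory: the first convolutions run at the full 480x640 resolution, so their outputs are the largest tensors in the network, and autograd keeps several of them (convolution outputs, activations, the skip connection) until the backward pass. A sketch of the arithmetic, assuming 64 channels in the first stage (as in the original U-Net paper, not necessarily the poster's model) and float32 activations:

```python
# Size of one full-resolution feature map in the first U-Net stage.
# Assumptions (not taken from the poster's model): 64 channels, as in the
# original U-Net paper, and float32 activations.
MIB = 1024 ** 2


def feature_map_bytes(batch, channels, height, width, bytes_per_value=4):
    """Bytes for one activation tensor of shape (batch, channels, height, width)."""
    return batch * channels * height * width * bytes_per_value


for batch in (1, 2, 4):
    size = feature_map_bytes(batch, 64, 480, 640)
    print(f"batch={batch}: {size / MIB:.0f} MiB per feature map")
```

Each such map grows linearly with the batch size (75 MiB at batch 1 and 150 MiB at batch 2 under these assumptions) and with the number of pixels, so halving both height and width cuts it by a factor of 4.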


(Justus Schock) #6

That seems to be the limit of your GPU. You can train with single-image batches, and mixed precision (e.g. via APEX) or gradient checkpointing might let you increase the batch size; beyond that, the only option is to reduce the image size.
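Two sketches of the memory-saving options mentioned above. Instead of APEX, they use PyTorch's built-in mixed precision (torch.autocast with a GradScaler), which later took over the role of APEX's amp, and torch.utils.checkpoint for gradient checkpointing. DoubleConv and train_step_amp are hypothetical names for illustration, not part of the poster's model:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class DoubleConv(nn.Module):
    """Two 3x3 conv + ReLU layers, a typical U-Net stage (hypothetical example)."""

    def __init__(self, in_ch, out_ch, use_checkpoint=False):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(),
        )

    def forward(self, x):
        if self.use_checkpoint and self.training:
            # The block's intermediate activations are not stored; they are
            # recomputed during backward, trading compute for memory.
            return checkpoint(self.block, x, use_reentrant=False)
        return self.block(x)


def train_step_amp(model, image, mask, criterion, optimizer, scaler, device):
    """One training step with automatic mixed precision on CUDA."""
    use_amp = device.type == 'cuda'
    optimizer.zero_grad()
    # On CUDA, autocast runs eligible ops (e.g. convolutions) in float16.
    with torch.autocast(device_type=device.type, enabled=use_amp):
        output = model(image)
        loss = criterion(output, mask)
    # GradScaler scales the loss to avoid float16 gradient underflow;
    # it is a pass-through when created with enabled=False.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()
```

The scaler would be created once, e.g. torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available()). Checkpointing only the full-resolution stages is likely to give most of the savings, since their activations are the largest; the cost is an extra forward pass through those blocks during backward. On a CPU-only machine autocast and the scaler are disabled, so the same code still runs.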


(Antoine) #7

OK, that’s what I thought :confused:
I will try this. Thank you for your help!