Out of memory error -- PyTorch allocates almost all GPU memory

I am trying to train a neural network with a PyTorch implementation of EfficientNetB5 on a Windows 11 machine with an RTX 4080 GPU, which has 16 GB of memory. The image dataset has 3 classes, with 12,500 training images (456 x 456 pixels) for each class, for a total occupied disk space of 12.9 GB. I use torchvision.datasets.ImageFolder to create training and validation sets from the training images. Using a batch size of 1 to be most conservative, I nonetheless get this error:
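For reference, the dataset and loaders are set up roughly like this (the paths, transforms, and split fraction below are placeholders rather than my exact code):

from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

# Illustrative transform: EfficientNet-B5 expects 456x456 inputs
transform = transforms.Compose([
    transforms.Resize((456, 456)),
    transforms.ToTensor(),
])

# Path is a placeholder for my actual training folder
dataset = datasets.ImageFolder("data/train", transform=transform)

# Roughly an 80/20 train/validation split
n_train = int(0.8 * len(dataset))
train_set, val_set = random_split(dataset, [n_train, len(dataset) - n_train])

trainloader = DataLoader(train_set, batch_size=1, shuffle=True)
validationloader = DataLoader(val_set, batch_size=1, shuffle=False)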

OutOfMemoryError: CUDA out of memory. Tried to allocate 96.00 MiB. GPU 0 has a total capacty of 15.99 GiB of which 0 bytes is free. Of the allocated memory 14.94 GiB is allocated by PyTorch, and 188.75 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

I confirmed with nvidia-smi that no other process is using a significant amount of GPU memory (output below). Also, torch.cuda.set_per_process_memory_fraction() is not called anywhere in my code.

I don’t understand why PyTorch is making so much memory unavailable; it seems to allocate just enough to trigger the out-of-memory error.
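If it helps with diagnosis, this is the kind of check I can add right before the line that crashes (just a sketch using torch.cuda's memory introspection functions; where exactly to call it is arbitrary):

import torch

def report_cuda_memory(tag=""):
    # Memory PyTorch has actually handed out to live tensors
    allocated = torch.cuda.memory_allocated() / 1024**2
    # Memory the caching allocator has reserved from the driver
    reserved = torch.cuda.memory_reserved() / 1024**2
    print(f"{tag}: allocated={allocated:.1f} MiB, reserved={reserved:.1f} MiB")

report_cuda_memory("before forward pass")
# For a much more detailed breakdown:
# print(torch.cuda.memory_summary())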

Thanks in advance.

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 528.49       Driver Version: 528.49       CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ... WDDM  | 00000000:01:00.0  On |                  N/A |
| 30%   29C    P8    13W / 288W |  16045MiB / 16376MiB |     11%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      4448    C+G   ...01.200\msedgewebview2.exe    N/A      |
|    0   N/A  N/A      8764    C+G   C:\Windows\explorer.exe         N/A      |
|    0   N/A  N/A     10440    C+G   ...n1h2txyewy\SearchHost.exe    N/A      |
|    0   N/A  N/A     10472    C+G   ...artMenuExperienceHost.exe    N/A      |
|    0   N/A  N/A     11792    C+G   ...y\ShellExperienceHost.exe    N/A      |
|    0   N/A  N/A     12344    C+G   ...2gh52qy24etm\Nahimic3.exe    N/A      |
|    0   N/A  N/A     12932    C+G   ...cw5n1h2txyewy\LockApp.exe    N/A      |
|    0   N/A  N/A     14796    C+G   ...e\PhoneExperienceHost.exe    N/A      |
|    0   N/A  N/A     15404    C+G   ...perience\NVIDIA Share.exe    N/A      |
|    0   N/A  N/A     16256    C+G   ...oft\OneDrive\OneDrive.exe    N/A      |
|    0   N/A  N/A     16428    C+G   ...me\Application\chrome.exe    N/A      |
|    0   N/A  N/A     17872    C+G   ...pingTool\SnippingTool.exe    N/A      |
|    0   N/A  N/A     18044    C+G   ...Vantage\LenovoVantage.exe    N/A      |
|    0   N/A  N/A     19068    C+G   ...8bbwe\WindowsTerminal.exe    N/A      |
|    0   N/A  N/A     20320    C+G   ...ge\Application\msedge.exe    N/A      |
|    0   N/A  N/A     22388    C+G   ...ark-0.29.4.0\gpushark.exe    N/A      |
|    0   N/A  N/A     24528    C+G   ...01.200\msedgewebview2.exe    N/A      |
|    0   N/A  N/A     25696    C+G   ...01.203\msedgewebview2.exe    N/A      |
|    0   N/A  N/A     26592    C+G   ...txyewy\CHXSmartScreen.exe    N/A      |
|    0   N/A  N/A     27040    C+G   ...lPanel\SystemSettings.exe    N/A      |
|    0   N/A  N/A     29772      C   ...rior\anaconda3\python.exe    N/A      |
+-----------------------------------------------------------------------------+

Hard to say without a minimal repro, but:

  • I believe EfficientNet-B5's weights should be about 100 MB
  • bs=1 with a 456x456 image should be negligible

My guess is that either the model is bigger than you expect because of changes you made, or you are moving all of the training images onto the GPU at once, which will OOM.
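You can sanity-check both numbers with something like this (a sketch using the stock torchvision EfficientNet-B5 and a reasonably recent torchvision; swap in your own model object to compare):

import torch
from torchvision import models

# Rough fp32 parameter footprint of a stock EfficientNet-B5 (~30M params)
model = models.efficientnet_b5(weights=None)
param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
print(f"Parameters: {param_bytes / 1024**2:.1f} MiB")  # roughly 115 MiB

# Size of a single 456x456 RGB input at batch size 1
x = torch.zeros(1, 3, 456, 456)
print(f"One input batch: {x.numel() * x.element_size() / 1024**2:.2f} MiB")  # ~2.4 MiB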

Thank you. No changes to the model, and I don’t think I’m moving all of the training images onto the GPU at once. But I’m new enough to PyTorch that I’ll post my plain-vanilla training loop:

import torch
from tqdm import tqdm

# Defined elsewhere in my script; shown here for completeness
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train(model, criterion, optimizer, trainloader, validationloader, epochs):
    best_val_loss = float("inf")
    for epoch in range(epochs):
        # Training pass
        model.train()
        train_loss = 0.0
        for i, (inputs, targets) in enumerate(tqdm(trainloader)):
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * inputs.size(0)
        train_loss /= len(trainloader.dataset)

        # Validation pass (no gradients)
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for i, (inputs, targets) in enumerate(tqdm(validationloader)):
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item() * inputs.size(0)
        val_loss /= len(validationloader.dataset)

        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
        # Keep the checkpoint with the best validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), "efficientnet.pth")