No GPU utilization despite GPU memory being used

I am trying to train a 1D convolutional neural network, which has a bare size of about 183 MB.
My data are WAV files of roughly 70 KB each. I prepare them with DataLoaders and then want to train the model as I would for other projects.

Now, if I set the batch size to 64, about 4 GB of my GPU memory is in use. My GPU utilization remains under 2%, though, and I get the following error message:

C:\ProgramData\Anaconda3\envs\user\Lib\site-packages\torch\autograd\graph.py:744: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at …\aten\src\ATen\native\cudnn\Conv_v8.cpp:919.)
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass

The program proceeds with training at about 0.6-1.0 seconds per batch, of which 0.2-0.3 seconds are spent loading the data. This feels quite slow, as I have to process about 5000 batches per epoch.
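
A rough sketch of how such per-batch timings can be measured (the names match the reduced training loop I post further down; torch.cuda.synchronize() makes sure the GPU work is actually included in the per-batch time):

import time
import torch

data_time, batch_time, num_batches = 0.0, 0.0, 0
end = time.perf_counter()
for inputs, labels in train_loader:
    data_time += time.perf_counter() - end      # time spent waiting for the DataLoader

    inputs, labels = inputs.to(device), labels.to(device)
    optimizer.zero_grad()
    loss = criterion(model(inputs), labels)
    loss.backward()
    optimizer.step()

    torch.cuda.synchronize()                    # wait for the GPU so the host-side timer is meaningful
    batch_time += time.perf_counter() - end
    num_batches += 1
    end = time.perf_counter()

print(f"data: {data_time / num_batches:.3f} s/batch, total: {batch_time / num_batches:.3f} s/batch")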

When I set the batch size to 256, I do not get the error mentioned above but a memory error instead. Before the memory error occurs, GPU utilization sits at around 60% for a while; afterwards it plunges to under 2% again (yet the model somehow keeps training).
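
For completeness, the built-in CUDA memory counters can be used to see how close a run gets to the limit before the memory error appears; a minimal check would be something like:

import torch

# Memory tracked by PyTorch on the current CUDA device (values in MB)
print(f"allocated: {torch.cuda.memory_allocated() / 1024**2:.0f} MB")
print(f"peak:      {torch.cuda.max_memory_allocated() / 1024**2:.0f} MB")
print(f"reserved:  {torch.cuda.memory_reserved() / 1024**2:.0f} MB")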

As I am not familiar with the first error message I mentioned, and I could not find a similar case elsewhere, could someone explain whether it is related to the low GPU utilization despite the memory usage?
Thanks in advance.

You should try to get rid of the errors first, either by isolating them yourself or by posting a minimal and executable code snippet that reproduces these issues in the latest PyTorch release.
To further narrow down which operations are the current bottleneck in your code, you could profile the workload with a visual profiler such as Nsight Systems.
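
If Nsight Systems is not at hand, the built-in torch.profiler already gives a first overview of whether the time is spent in CUDA kernels, on the CPU, or waiting for data. A minimal sketch (the dummy conv layer, loss, and input below are only stand-ins for your real model and DataLoader):

import torch
from torch.profiler import profile, ProfilerActivity

# Dummy stand-ins so the snippet runs on its own; replace with your real model, loss, and data
model = torch.nn.Conv1d(1, 64, kernel_size=11, stride=2).cuda()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
x = torch.randn(64, 1, 17640, device="cuda")

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(10):  # a few iterations are enough for a first overview
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, torch.zeros_like(out))
        loss.backward()
        optimizer.step()

# Shows which CPU/CUDA ops dominate the measured runtime
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))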

This is a reduced version of my code:

import torch  # 2.3.0+cu121
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np  # 1.26.0


# Defines building blocks for the model
class SRBlock2(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super(SRBlock2, self).__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=2)
        self.bn = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        out = self.pool(out)
        return out

class LastConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super(LastConv, self).__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size)

    def forward(self, x):
        out = self.conv(x)
        out = out.view(out.size(0), -1)
        return out

# "The model" (reduced enough to still reproduce the error)
class nEMGNet(nn.Module):
    def __init__(self):
        super(nEMGNet, self).__init__()
        self.model = self.net_B()

    def forward(self, x):
        return self.model(x)

    def net_B(self):
        model = nn.Sequential(
            SRBlock2(1, 64, kernel_size=11),
            LastConv(64, 64, 4407),
            nn.Linear(64, 16),
            nn.Linear(16, 2),
            nn.Softmax(dim=1)
        )
        return model

    def net_A(self):
        model = nn.Sequential(
            SRBlock2(1, 64, kernel_size=11),
            SRBlock2(64, 128, kernel_size=7),
            LastConv(128, 128, 1100),
            nn.Linear(128, 64),
            nn.Linear(64, 16),
            nn.Linear(16,2),
            nn.Softmax(dim=1)
        )
        return model


class RandomDataset(Dataset):
    def __init__(self, num_samples, num_features, num_classes):
        self.data = torch.randn(num_samples, 1, num_features)
        self.labels = torch.randint(0, num_classes, (num_samples,))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample, label = self.data[idx], self.labels[idx]
        return sample, label

# Initializes random dataset of similar shape as original
num_samples = 1000
num_features = 17640
num_classes = 2
batch_size = 64

dataset = RandomDataset(num_samples, num_features, num_classes)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True)

# Defines training parameters
lr = 1e-3
epochs = 500
patience = 10
best_loss = 1.0
no_improvement = 0

# Creates model, optimizer, and loss function
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device used: {device}")
model = nEMGNet().to(device)
optimizer = torch.optim.NAdam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

# Training loop
train_losses = []
for epoch in range(epochs):
    # Training
    model.train()
    train_loss = 0.0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        train_loss += loss.item() * inputs.size(0)

    train_loss /= len(train_loader.dataset)
    train_losses.append(train_loss)
    print(f"epoch {epoch}, loss {train_loss}")

When I run this, I get an error similar to the one mentioned before. This time the message is accompanied by many, many extra lines, such as:

C:\ProgramData\Anaconda3\envs\user\Lib\site-packages\torch\autograd\graph.py:744: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ..\aten\src\ATen\native\cudnn\Conv_v8.cpp:919.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
C:\ProgramData\Anaconda3\envs\user\Lib\site-packages\torch\autograd\graph.py:744: UserWarning: Plan failed with a CuDNNError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED
Exception raised from run_conv_plan at ..\aten\src\ATen\native\cudnn\Conv_v8.cpp:375 (most recent call first):
00007FF9D6CA366200007FF9D6CA3600 c10.dll!c10::Error::Error [<unknown file> @ <unknown line number>]
00007FF8E854784600007FF8E8544EB0 torch_cuda.dll!at::native::_fft_r2c_cufft_out [<unknown file> @ <unknown line number>]
00007FF8E858257000007FF8E85518C0 torch_cuda.dll!at::native::cudnn_convolution_transpose [<unknown file> @ <unknown line number>]
00007FF8E858806600007FF8E85518C0 torch_cuda.dll!at::native::cudnn_convolution_transpose [<unknown file> @ <unknown line number>]
00007FF8E429EA9300007FF8E428CEA0 torch_python.dll!THPPointer<PyCodeObject>::THPPointer<PyCodeObject> [<unknown file> @ <unknown line number>]
00007FF920331EF500007FF920331410 torch_cpu.dll!torch::autograd::Engine::get_base_engine [<unknown file> @ <unknown line number>]
00007FFA1E7D268A00007FFA1E7D2630 ucrtbase.dll!o_exp [<unknown file> @ <unknown line number>]
00007FFA20337AC400007FFA20337AB0 KERNEL32.DLL!BaseThreadInitThunk [<unknown file> @ <unknown line number>]
00007FFA225AA8C100007FFA225AA8A0 ntdll.dll!RtlUserThreadStart [<unknown file> @ <unknown line number>]
 (Triggered internally at ..\aten\src\ATen\native\cudnn\Conv_v8.cpp:921.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "C:\Users\user\Documents\code\py_files\Error repro.py", line 103, in <module>
    loss.backward()
  File "C:\ProgramData\Anaconda3\envs\user\Lib\site-packages\torch\_tensor.py", line 525, in backward
    torch.autograd.backward(
  File "C:\ProgramData\Anaconda3\envs\user\Lib\site-packages\torch\autograd\__init__.py", line 267, in backward
    _engine_run_backward(
  File "C:\ProgramData\Anaconda3\envs\user\Lib\site-packages\torch\autograd\graph.py", line 744, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: GET was unable to find an engine to execute this computation

(I trimmed quite a few of the lines starting with 00007…; if necessary I can share them.)
The original error, with which the program keeps running, is reproduced on my machine by setting self.model = self.net_A() in the __init__ of the nEMGNet class.

I hope this helps in determining what the error stems from.

Upon closer inspection, it seems some of the CUDA cores are actually being used, but overall utilization is still very low (so either the workload is too light or there is an I/O problem).
Regarding the error, I also see that the CUDA version on the server is 10.1, and the cuDNN version is likely 8.0.5.
From what I could find, PyTorch 2.3 (which I use) is stable on newer CUDA and cuDNN versions.
I am not sure to what extent newer PyTorch releases are backward compatible, but could this be the cause of the cuDNN error?
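
In case it helps, this is a quick way to check which CUDA runtime and cuDNN version the PyTorch wheel itself reports (as opposed to the toolkit installed on the server):

import torch

print(torch.__version__)                   # 2.3.0+cu121 in my case
print(torch.version.cuda)                  # CUDA runtime the wheel was built with
print(torch.backends.cudnn.version())      # cuDNN version PyTorch actually loads
print(torch.backends.cudnn.is_available())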

I cannot reproduce any issues using your code and see:

Device used: cuda
epoch 0, loss 0.7132615447044373
epoch 1, loss 0.7382615208625793
epoch 2, loss 0.7882614731788635
epoch 3, loss 0.6882615089416504
epoch 4, loss 0.9382616877555847
epoch 5, loss 0.7382615208625793
epoch 6, loss 0.638261616230011
epoch 7, loss 0.788261353969574
epoch 8, loss 0.8132615089416504
epoch 9, loss 0.8132616281509399
...

Yes, this matches my previous recommendation:

The PyTorch binaries ship with their own CUDA runtime dependencies.
If you mix these with your system libs, you might run into conflicts. As a debugging step you could remove the locally installed CUDA toolkit and cuDNN from the LD_LIBRARY_PATH (or the equivalent on Windows) to make sure PyTorch loads its own CUDA libs.
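
On Windows the equivalent would be the PATH variable. As a rough sketch (the substring match is only a heuristic), you could list the entries pointing to a local CUDA or cuDNN install and temporarily remove them before starting Python:

import os

# Heuristic: show PATH entries that point to a system-wide CUDA toolkit or cuDNN install,
# which could shadow the libraries bundled with the PyTorch wheel
for entry in os.environ.get("PATH", "").split(os.pathsep):
    if "cuda" in entry.lower() or "cudnn" in entry.lower():
        print(entry)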