CUDA Invalid Configuration Error on GPU Only

Hi everyone, long-time reader, first-time poster!

I seem to be running into an odd bug when training my model. I am trying to train an embedding model based on the ResNet18 architecture, from which I have essentially removed the final linear layer (I can post the architecture if needed). I have written my own loss function, which implements the “Batch-Hard” loss found in this paper. The training loop (standard zero_grad, output, loss, backward, step) works fine when running on CPU, but throws the following error when running on GPU:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-11-8cf2cfddbcc6> in <module>()
      3     output = learner.model(imgs)
      4     loss = criterion(output, labels)
----> 5     loss.backward()
      6     optimizer.step()
      7     break

.../anaconda/anaconda3-2019a/lib/python3.6/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    105                 products. Defaults to ``False``.
    106         """
--> 107         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    108 
    109     def register_hook(self, hook):

.../anaconda/anaconda3-2019a/lib/python3.6/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     91     Variable._execution_engine.run_backward(
     92         tensors, grad_tensors, retain_graph, create_graph,
---> 93         allow_unreachable=True)  # allow_unreachable flag
     94 
     95 

RuntimeError: CUDA error: invalid configuration argument

Does anyone know why a backwards pass would work on a CPU but not a GPU? I can confirm that my GPU works generally, I have trained other models with other custom loss functions.

For reference, this is the implementation of the “Batch-Hard” loss that I’ve written:

import numpy as np
import torch


class BatchHardLoss(torch.nn.modules.loss._Loss):
    def __init__(self, margin, k=5, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.margin = margin
        self._k = k

    def forward(self, x, targets):
        distances = torch.cdist(x, x)  # pairwise distances between embeddings
        losses = []
        for t in torch.unique(targets):
            pos_idxs = (targets == t).nonzero().reshape(-1)  # positive indices
            neg_idxs = (targets != t).nonzero().reshape(-1)  # negative indices

            # sample up to k anchors from the positives (with replacement)
            choice = np.random.choice(
                len(pos_idxs), min(self._k, len(pos_idxs))
            ).tolist()
            anchors = pos_idxs[choice]  # chosen anchors

            # Only compute the loss if we can do it for the full batch, for simplicity
            if len(anchors) > 1 and len(neg_idxs) > 0:
                # hardest positive: the furthest same-class sample per anchor
                pos_distances = torch.index_select(
                    torch.index_select(distances, 0, anchors), 1, pos_idxs
                )
                pos_loss = pos_distances.max(dim=1)[0].sum()

                # hardest negative: the closest other-class sample per anchor
                neg_distances = torch.index_select(
                    torch.index_select(distances, 0, anchors), 1, neg_idxs
                )
                neg_loss = neg_distances.min(dim=1)[0].sum()

                marg_loss = self.margin * len(anchors)

                losses.append(marg_loss + pos_loss - neg_loss)
            else:
                losses.append(torch.tensor(0.0, device=x.device, requires_grad=True))
        return torch.stack(losses).sum()
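
In case it helps, this is roughly how I'm calling it (a minimal sketch with random data; the margin and the number of classes are placeholders, and the embedding size matches the 512-dimensional ResNet18 features):

import torch

criterion = BatchHardLoss(margin=0.5, k=5)

# random stand-ins for a batch of embeddings and labels
embeddings = torch.randn(1024, 512, device="cuda", requires_grad=True)
labels = torch.randint(0, 10, (1024,), device="cuda")

loss = criterion(embeddings, labels)
loss.backward()  # this is where the error is raised on GPU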

If anyone sees any glaring errors or has any insight, I’d really appreciate it.

Thanks!

P.S. I have seen this post but can’t seem to find a similar issue when debugging.

I don’t see anything obviously wrong in your code.
Could you print the output and labels tensors before calling loss.backward()?
Maybe we can isolate the issue to these values and reproduce the error.

If these tensors do not yield the same error, could you additionally store the state_dict of your model and the current imgs?
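
Something like this should capture everything (a sketch; "repro.pt" is just a placeholder file name):

# save the tensors and model state so the failure can be reproduced offline
torch.save({
    "output": output.detach().cpu(),
    "labels": labels.cpu(),
    "imgs": imgs.cpu(),
    "state_dict": learner.model.state_dict(),
}, "repro.pt")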


Thanks for replying @ptrblck! While trying to reproduce the error at home (I was at work when I first hit it), I started with a smaller batch size (I had been using 1024), and lo and behold, it worked! I did some investigating, and apparently a batch size of 1023 works, but any batch_size >= 1024 causes the error. It seems to me that it’s a GPU memory issue, but I don’t know why I’m not getting the traditional RuntimeError: CUDA out of memory. I’ll double-check that this is the case on my work setup as well and report back tomorrow.
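
This is roughly how I narrowed it down (a sketch; the random data stands in for my real batches, and the margin and number of classes are placeholders):

import torch

criterion = BatchHardLoss(margin=0.5, k=5)
for bs in (256, 512, 1023, 1024, 2048):
    x = torch.randn(bs, 512, device="cuda", requires_grad=True)
    y = torch.randint(0, 10, (bs,), device="cuda")
    try:
        criterion(x, y).backward()
        print(bs, "ok")
    except RuntimeError as e:
        print(bs, "failed:", e)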


That does seem to have been the problem. I’m still not sure why the error wasn’t the usual out of memory, but glad it’s fixed!


Good to hear it’s fixed now!
That’s indeed a bit weird, but sometimes OOM errors are masked by e.g. cudnn errors.
I haven’t seen this invalid config error yet.


I’m receiving the same error and can’t seem to debug it. I don’t think it has to do with my GPU’s memory either: the computations aren’t that large, and it still fails with a batch size of 1. When I add an extra component to my loss function, I get the same error thrown on the loss.backward() step. Here is an excerpt from my loss function:

def forward(self, anchor, positive, negative, model, size_average=True):
    # regular triplet loss; this works on both GPU and CPU
    distance_positive = (anchor - positive).pow(2).sum(1)  # .pow(.5)
    distance_negative = (anchor - negative).pow(2).sum(1)  # .pow(.5)
    losses = F.relu(distance_positive - distance_negative + self.margin)

    # the additional component that causes the error; this runs on CPU but fails on GPU
    anchor_dists = torch.cdist(model.embedding_net.anchor_net.anchors,
                               model.embedding_net.anchor_net.anchors)
    t = self.beta * F.relu(self.rho - anchor_dists)
    regularization = t.sum() - torch.diag(t).sum()

    # note the parentheses: without them, the ternary binds the whole sum and
    # the regularization term is silently dropped when size_average is False
    return regularization + (losses.mean() if size_average else losses.sum())
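
The failing component can be run in isolation like this (a sketch; the anchors tensor has the same torch.Size([128, 192]) shape as in the model below, and the beta/rho values are placeholders):

import torch
import torch.nn.functional as F

anchors = torch.randn(128, 192, device="cuda", requires_grad=True)
anchor_dists = torch.cdist(anchors, anchors)    # pairwise distances between anchors
t = 1.0 * F.relu(0.5 - anchor_dists)            # beta * relu(rho - d), placeholder values
regularization = t.sum() - torch.diag(t).sum()  # drop the diagonal self-distances
regularization.backward()                       # the backward pass is what fails on GPU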

Here is my model:

TripletNet(
  (embedding_net): EmbeddingNet(
    (anchor_net): AnchorNet(anchors torch.Size([128, 192]), biases torch.Size([128]))
    (embedding): Sequential(
      (0): AnchorNet(anchors torch.Size([128, 192]), biases torch.Size([128]))
      (1): Tanh()
    )
  )
)

This is much smaller than the memory capacity of my GPU (8 GB).

EDIT: Monitoring the GPU memory usage shows that I’m well under the memory limit when it crashes.


Any ideas @ptrblck ? Thanks!

Could you post a code snippet to reproduce this issue so that we could have a look?

Hi @ptrblck, I’d be happy to. Here is an excerpt from the file Runner.ipynb that generates the error:

cuda = torch.cuda.is_available()
# define the datasets
triplet_train_dataset = TripletAudio(True, K, MAX_CLOSE_NEG, P_STRONG_NEG)
triplet_test_dataset = TripletAudio(False, K, MAX_CLOSE_NEG, P_STRONG_NEG)
triplet_train_loader = torch.utils.data.DataLoader(triplet_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
triplet_test_loader = torch.utils.data.DataLoader(triplet_test_dataset, batch_size=BATCH_SIZE, shuffle=False)
# define the model
anchor_net = AnchorNet(triplet_train_dataset.get_dataset(), INPUT_D, OUTPUT_D)
embedding_net = EmbeddingNet(anchor_net)
model = TripletNet(embedding_net)
if cuda:
    model.cuda()
loss_fn = TripletLoss(MARGIN, RHO, BETA)
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1, last_epoch=-1)
# run the model
train_loss, val_loss = fit(triplet_train_loader, triplet_test_loader, model, loss_fn, optimizer, scheduler, N_EPOCHS, cuda, LOG_INT)
# ^^^^^^ the parent function call that generates the error ^^^^^^

And the error:

/opt/anaconda3/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:82: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule.See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
  "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-3-bda3ed742c36> in <module>
     16 scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1, last_epoch=-1)
     17 #run the model
---> 18 train_loss, val_loss = fit(triplet_train_loader, triplet_test_loader, model, loss_fn, optimizer, scheduler, N_EPOCHS, cuda, LOG_INT)

~/thesis/trainer.py in fit(train_loader, val_loader, model, loss_fn, optimizer, scheduler, n_epochs, cuda, log_interval, metrics, start_epoch)
     24         scheduler.step()
     25         # Train stage
---> 26         train_loss, metrics, writer_train_index = train_epoch(train_loader, model, loss_fn, optimizer, cuda, log_interval, metrics, writer, writer_train_index)
     27 
     28         message = 'Epoch: {}/{}. Train set: Average loss: {:.4f}'.format(epoch + 1, n_epochs, train_loss)

~/thesis/trainer.py in train_epoch(train_loader, model, loss_fn, optimizer, cuda, log_interval, metrics, writer, writer_train_index)
     80         losses.append(loss.item())
     81         total_loss += loss.item()
---> 82         loss.backward()
     83         optimizer.step()
     84 

/opt/anaconda3/lib/python3.7/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    116                 products. Defaults to ``False``.
    117         """
--> 118         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    119 
    120     def register_hook(self, hook):

/opt/anaconda3/lib/python3.7/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     91     Variable._execution_engine.run_backward(
     92         tensors, grad_tensors, retain_graph, create_graph,
---> 93         allow_unreachable=True)  # allow_unreachable flag
     94 
     95 

RuntimeError: CUDA error: invalid configuration argument

The fit function in trainer.py essentially just calls this sub-function:

def train_epoch(train_loader, model, loss_fn, optimizer, cuda, log_interval, metrics, writer, writer_train_index):
    for metric in metrics:
        metric.reset()

    model.train()
    losses = []
    total_loss = 0

    for batch_idx, (data, target, index) in enumerate(train_loader):
        target = target if len(target) > 0 else None
        if not isinstance(data, (tuple, list)):
            data = (data,)
        if cuda:
            data = tuple(d.cuda() for d in data)
            if target is not None:
                target = target.cuda()

        optimizer.zero_grad()
        outputs = model(*data)
        if not isinstance(outputs, (tuple, list)):
            outputs = (outputs,)

        loss_inputs = outputs
        if target is not None:
            target = (target,)
            loss_inputs += target

        loss_outputs = loss_fn(*loss_inputs, model)

        loss = loss_outputs[0] if isinstance(loss_outputs, (tuple, list)) else loss_outputs
        losses.append(loss.item())
        total_loss += loss.item()
        loss.backward()  # this is the line that generates the error
        optimizer.step()

And the loss function is supplied above. The full files (Runner.ipynb, trainer.py, losses.py) can be found here if desired.

Thanks for your help, by the way! This has had me very, very stuck. If there is anything else that would help, please let me know.

Thanks for the code!
It looks like the error is thrown in the backward call, so we can probably reproduce it using some dummy input and target data, without creating the real dataset.
Could you post the shapes and types of data and target?

target is actually None (the logic is generalizable and this particular loss doesn’t use target), and data is a list of three tensors (a triplet), each of size torch.Size([128, 192, 1]) and type torch.FloatTensor.
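
For reference, dummy inputs with the same shapes can be built like this (a sketch):

import torch

# a triplet of random tensors matching the real batch shapes
data = tuple(torch.randn(128, 192, 1) for _ in range(3))
target = None  # unused by this loss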

I still can’t run the code, as AnchorNet relies on some data.
Could you post some dummy input shapes to train_data (or an executable code snippet)?

Hi @ptrblck, here is a runnable colab document with everything stripped but the bare requirements to produce the error. The data loaders contain randomly generated tensors of the appropriate dimension. I hope this helps.

I’m seeing the same problem with the torch.cdist function.

I also have the same error message. BTW, I use torch.pdist as well, which seems to hit the same problem as @Naruto-Sasuke’s.
It works fine on CPU, but throws this error on GPU.

torch version: 1.4.0
CUDA: 10.1

Could you update to the latest stable PyTorch version (1.5.0), as we’ve fixed some issues in cdist and pdist regarding these invalid configurations?
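
You can check which version you’re actually running with:

import torch
print(torch.__version__)  # should report 1.5.0 after the upgrade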

@ptrblck thanks for the reply.
I’ve updated PyTorch from 1.4.0 to 1.5.0. Works perfectly.
Thanks again.

Same issue here with PyTorch 1.5.0.
It works at batch size = 512, but crashes at batch size = 1024 with the CUDA invalid configuration error, related to cdist.
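
For reference, this is roughly what I’m running (a sketch; the feature dimension is illustrative):

import torch

x = torch.randn(1024, 128, device="cuda", requires_grad=True)  # fine at 512 rows, fails at 1024
d = torch.cdist(x, x)
d.sum().backward()  # the invalid configuration error is raised around here for me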