Greatly Reduced Validation Accuracy and Segfault When Training on GPU?

I’m getting very weird behavior when training a basic CNN text classification model with CUDA 11.1 on an AWS machine. I have tried this on both p2 and g3 instance types. When training on the CPU, the model does very well and achieves about 85% validation accuracy after ~15 epochs. However, using the exact same code with CUDA gives strange results. I get random segmentation faults during the training process. In addition, the validation accuracy barely changes between batches even though the loss does change, and it hovers around 65%. Here is my code:

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from sklearn.utils import shuffle

from cnn import BasicConvModel
from data import (
    load_data,
    preprocess_text,
    create_vocabulary,
    get_valid_accuracy
)

# Fix the RNG seed so weight initialisation and dropout masks are repeatable.
torch.manual_seed(12)

# Use the GPU when PyTorch can see one; replace with torch.device("cpu") to
# compare against a CPU-only run.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Model / optimisation hyper-parameters.
output_size = 1
embed_dim = 600
num_filters = 200
kernel_sizes = [2, 3, 4, 5]
batch_size = 16

X, y = load_data('../data/classification_sample.csv')

processed_text = preprocess_text(X)
vocab = create_vocabulary(processed_text)

model = BasicConvModel(vocab, embed_dim, num_filters, kernel_sizes, output_size)
# Single code path for CPU and GPU: the model lands on the same device as the
# tensors created below.
model.to(device)

loss_fn = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

processed_text, y = shuffle(processed_text, y, random_state=132)

# Right-pad every token sequence with <PAD> up to the longest sequence so the
# whole corpus fits in one rectangular index tensor.
max_idx = max(len(text) for text in processed_text)
pad_idx = vocab['<PAD>']
text_idxs = [
    [vocab[tok] for tok in text] + [pad_idx] * (max_idx - len(text))
    for text in processed_text
]

full_tensor = torch.tensor(text_idxs, device=device)
y = torch.tensor(y, device=device)

num_epochs = 20

# The first 20% of the shuffled rows form the validation split; the rest is
# used for training.
valid_idx = int(0.2 * full_tensor.size(0))
X, valid_X = full_tensor[valid_idx:, :], full_tensor[:valid_idx, :]
y, valid_y = y[valid_idx:], y[:valid_idx]

print_every = 10

model.train()
for i in range(num_epochs):
    # Step through the data by start offset: this never yields an empty
    # trailing batch, even when the number of rows is an exact multiple of
    # batch_size.
    for j, start_idx in enumerate(range(0, X.size(0), batch_size)):
        end_idx = start_idx + batch_size
        batch_X, batch_y = X[start_idx:end_idx, :], y[start_idx:end_idx]

        optimizer.zero_grad()
        preds = model(batch_X)

        # squeeze(1) drops only the output dimension, so a final batch holding
        # a single sample keeps shape (1,) and still matches its target.
        loss = loss_fn(preds.squeeze(1), batch_y.float())
        loss.backward()
        optimizer.step()

        if j % print_every == 0:
            acc = get_valid_accuracy(valid_X, valid_y, model)
            # get_valid_accuracy puts the model in eval mode; switch back so
            # dropout stays active for the following training steps.
            model.train()
            print(f"Epoch: {i}, Loss: {loss}, Validation Accuracy: {acc}")

And the CNN model

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class BasicConvModel(nn.Module):
    """Convolutional text classifier that outputs sigmoid probabilities.

    Pipeline: token embedding -> one convolution per kernel size (each spanning
    the full embedding width) -> ReLU -> max-pool over the sequence axis ->
    concatenation -> optional dropout -> linear layer -> sigmoid.

    Args:
        vocab: Sized vocabulary container; ``len(vocab)`` is the number of
            embedding rows.
        embed_dim: Size of each token embedding.
        num_filters: Output channels of every convolution.
        kernel_sizes: Iterable of window heights (tokens per window), one
            convolution each. Values are expected to be >= 2 so that the
            ``k - 2`` padding stays non-negative.
        output_size: Number of output units of the final linear layer.
        use_dropout: When False, the dropout layer is skipped in ``forward``.
        dropout_prob: Drop probability of the dropout layer.
    """

    def __init__(self, vocab, embed_dim, num_filters, kernel_sizes,
                 output_size, use_dropout=True, dropout_prob=0.2):
        super().__init__()
        self.vocab = vocab
        self.embed_dim = embed_dim
        self.num_filters = num_filters
        self.kernel_sizes = kernel_sizes
        self.output_size = output_size
        self.use_dropout = use_dropout
        self.dropout_prob = dropout_prob

        # embedding layer
        self.embed = nn.Embedding(len(vocab), embed_dim)

        # One 2-d convolution per kernel size. Each kernel covers the whole
        # embedding width, so it slides along the sequence axis only and acts
        # as a 1-d convolution over tokens.
        self.convs_1d = nn.ModuleList([
            nn.Conv2d(1, num_filters, (k, embed_dim), padding=(k - 2, 0))
            for k in kernel_sizes])

        # dropout layer (always constructed; applied only if use_dropout)
        self.dropout = nn.Dropout(p=dropout_prob)

        self.dense = nn.Linear(len(kernel_sizes) * num_filters, output_size)
        self.sigmoid = nn.Sigmoid()

    def conv_and_pool(self, x, conv):
        """Apply ``conv`` + ReLU to ``x`` and max-pool over the sequence axis.

        Returns one value per filter, i.e. a (batch, num_filters) tensor for a
        batched 4-d input.
        """
        # The kernel spans the full embedding width, so the last axis has
        # size 1 after the convolution and can be squeezed away.
        x = F.relu(conv(x)).squeeze(3)
        # Pooling window == remaining sequence length -> global max per filter.
        return F.max_pool1d(x, x.size(2)).squeeze(2)

    def forward(self, x):
        """Return sigmoid probabilities for a batch of token-index sequences.

        ``x`` is expected to be an integer index tensor of shape
        (batch, seq_len); the result has shape (batch, output_size).
        """
        # Insert a channel axis so Conv2d sees (batch, 1, seq_len, embed_dim).
        embeds = self.embed(x).unsqueeze(1)

        pooled = [self.conv_and_pool(embeds, conv) for conv in self.convs_1d]

        x = torch.cat(pooled, 1)
        # Honour the constructor flag instead of applying dropout regardless.
        if self.use_dropout:
            x = self.dropout(x)

        logit = self.dense(x)

        return self.sigmoid(logit)

The output from the CPU

Epoch: 0, Loss: 0.5930776596069336, Validation Accuracy: 0.6538461538461539
Epoch: 0, Loss: 0.5337628722190857, Validation Accuracy: 0.7062937062937062
Epoch: 0, Loss: 0.8541656732559204, Validation Accuracy: 0.6783216783216783
Epoch: 0, Loss: 1.5076597929000854, Validation Accuracy: 0.7412587412587412
Epoch: 0, Loss: 1.055909514427185, Validation Accuracy: 0.7727272727272727
Epoch: 0, Loss: 0.4178897440433502, Validation Accuracy: 0.7797202797202797
Epoch: 0, Loss: 0.4650037884712219, Validation Accuracy: 0.7867132867132867
Epoch: 0, Loss: 0.3287246823310852, Validation Accuracy: 0.7867132867132867
Epoch: 1, Loss: 0.09397178143262863, Validation Accuracy: 0.7972027972027972
Epoch: 1, Loss: 0.355461061000824, Validation Accuracy: 0.7657342657342657
Epoch: 1, Loss: 0.05974849686026573, Validation Accuracy: 0.7517482517482518
Epoch: 1, Loss: 0.34798571467399597, Validation Accuracy: 0.7867132867132867
Epoch: 1, Loss: 0.21483170986175537, Validation Accuracy: 0.7797202797202797
Epoch: 1, Loss: 0.0788840726017952, Validation Accuracy: 0.7307692307692307
Epoch: 1, Loss: 0.07168418169021606, Validation Accuracy: 0.8006993006993007
Epoch: 1, Loss: 0.1160748153924942, Validation Accuracy: 0.8041958041958042
Epoch: 2, Loss: 0.05112272500991821, Validation Accuracy: 0.7937062937062938
Epoch: 2, Loss: 0.20592114329338074, Validation Accuracy: 0.7937062937062938
Epoch: 2, Loss: 0.05088387429714203, Validation Accuracy: 0.7937062937062938
Epoch: 2, Loss: 0.0787544697523117, Validation Accuracy: 0.8111888111888111
Epoch: 2, Loss: 0.5769421458244324, Validation Accuracy: 0.7587412587412588
Epoch: 2, Loss: 0.0004902268410660326, Validation Accuracy: 0.7972027972027972
Epoch: 2, Loss: 0.4308287501335144, Validation Accuracy: 0.7867132867132867
Epoch: 2, Loss: 0.007755571510642767, Validation Accuracy: 0.7902097902097902
Epoch: 3, Loss: 0.0071701593697071075, Validation Accuracy: 0.7622377622377622
Epoch: 3, Loss: 0.013051643036305904, Validation Accuracy: 0.7727272727272727
Epoch: 3, Loss: 0.007425243500620127, Validation Accuracy: 0.7412587412587412

and an example CUDA run’s output

Epoch: 0, Loss: 0.6711463928222656, Validation Accuracy: 0.6538461538461539
Epoch: 0, Loss: 0.7122050523757935, Validation Accuracy: 0.458041958041958
Epoch: 0, Loss: 0.724312961101532, Validation Accuracy: 0.6188811188811189
Epoch: 0, Loss: 0.7947384715080261, Validation Accuracy: 0.6258741258741258
Epoch: 0, Loss: 0.688117265701294, Validation Accuracy: 0.6188811188811189
Epoch: 0, Loss: 0.5511382222175598, Validation Accuracy: 0.6503496503496503
Epoch: 0, Loss: 0.6020643711090088, Validation Accuracy: 0.6538461538461539
Epoch: 0, Loss: 0.6449810266494751, Validation Accuracy: 0.6643356643356644
Epoch: 1, Loss: 0.645209550857544, Validation Accuracy: 0.6713286713286714
Epoch: 1, Loss: 0.6702262163162231, Validation Accuracy: 0.6643356643356644
Epoch: 1, Loss: 0.7097079753875732, Validation Accuracy: 0.6538461538461539
Epoch: 1, Loss: 0.7172590494155884, Validation Accuracy: 0.6538461538461539
Epoch: 1, Loss: 0.6835118532180786, Validation Accuracy: 0.6363636363636364
Epoch: 1, Loss: 0.5769379734992981, Validation Accuracy: 0.6538461538461539
Epoch: 1, Loss: 0.6344826817512512, Validation Accuracy: 0.6468531468531469
Epoch: 1, Loss: 0.6470801830291748, Validation Accuracy: 0.6433566433566433
Epoch: 2, Loss: 0.6567261815071106, Validation Accuracy: 0.6363636363636364
Epoch: 2, Loss: 0.6605111360549927, Validation Accuracy: 0.6433566433566433
Epoch: 2, Loss: 0.716945469379425, Validation Accuracy: 0.6433566433566433
Epoch: 2, Loss: 0.7049784660339355, Validation Accuracy: 0.6503496503496503
Epoch: 2, Loss: 0.6796182990074158, Validation Accuracy: 0.6538461538461539
Epoch: 2, Loss: 0.571618914604187, Validation Accuracy: 0.6538461538461539
Epoch: 2, Loss: 0.6348196268081665, Validation Accuracy: 0.6538461538461539
Epoch: 2, Loss: 0.6413698196411133, Validation Accuracy: 0.6468531468531469
Epoch: 3, Loss: 0.6526831388473511, Validation Accuracy: 0.6468531468531469
Epoch: 3, Loss: 0.6591930389404297, Validation Accuracy: 0.6538461538461539
Epoch: 3, Loss: 0.7166335582733154, Validation Accuracy: 0.6643356643356644
Epoch: 3, Loss: 0.7043871879577637, Validation Accuracy: 0.6643356643356644
Epoch: 3, Loss: 0.6688210964202881, Validation Accuracy: 0.6573426573426573
Epoch: 3, Loss: 0.5581091642379761, Validation Accuracy: 0.6643356643356644
Epoch: 3, Loss: 0.6201637387275696, Validation Accuracy: 0.6433566433566433
Epoch: 3, Loss: 0.6303338408470154, Validation Accuracy: 0.6433566433566433
Epoch: 4, Loss: 0.6430555582046509, Validation Accuracy: 0.6433566433566433
Epoch: 4, Loss: 0.6545212864875793, Validation Accuracy: 0.6468531468531469
Epoch: 4, Loss: 0.710310161113739, Validation Accuracy: 0.6503496503496503
Epoch: 4, Loss: 0.7034661173820496, Validation Accuracy: 0.6538461538461539
Epoch: 4, Loss: 0.6596505641937256, Validation Accuracy: 0.6573426573426573
Epoch: 4, Loss: 0.5458865165710449, Validation Accuracy: 0.6538461538461539
Epoch: 4, Loss: 0.6087003350257874, Validation Accuracy: 0.6468531468531469
Epoch: 4, Loss: 0.617493748664856, Validation Accuracy: 0.6363636363636364
Epoch: 5, Loss: 0.6346374750137329, Validation Accuracy: 0.6433566433566433
Epoch: 5, Loss: 0.6527444124221802, Validation Accuracy: 0.6433566433566433
Epoch: 5, Loss: 0.7063688039779663, Validation Accuracy: 0.6538461538461539
Epoch: 5, Loss: 0.7048066854476929, Validation Accuracy: 0.6538461538461539
Epoch: 5, Loss: 0.6533452272415161, Validation Accuracy: 0.6538461538461539
Epoch: 5, Loss: 0.5325800180435181, Validation Accuracy: 0.6573426573426573
Epoch: 5, Loss: 0.5976376533508301, Validation Accuracy: 0.6503496503496503
Epoch: 5, Loss: 0.6109954118728638, Validation Accuracy: 0.6503496503496503
Epoch: 6, Loss: 0.626984715461731, Validation Accuracy: 0.6433566433566433
Epoch: 6, Loss: 0.6500746011734009, Validation Accuracy: 0.6468531468531469
Epoch: 6, Loss: 0.7023633718490601, Validation Accuracy: 0.6503496503496503
Epoch: 6, Loss: 0.7072952389717102, Validation Accuracy: 0.6573426573426573
Epoch: 6, Loss: 0.6468454003334045, Validation Accuracy: 0.6573426573426573
Epoch: 6, Loss: 0.5176084637641907, Validation Accuracy: 0.6573426573426573
Epoch: 6, Loss: 0.5881103277206421, Validation Accuracy: 0.6573426573426573
Epoch: 6, Loss: 0.602046549320221, Validation Accuracy: 0.6468531468531469
Epoch: 7, Loss: 0.6179436445236206, Validation Accuracy: 0.6433566433566433
Epoch: 7, Loss: 0.6477866768836975, Validation Accuracy: 0.6468531468531469
Epoch: 7, Loss: 0.7008408904075623, Validation Accuracy: 0.6608391608391608
Epoch: 7, Loss: 0.7075467109680176, Validation Accuracy: 0.6573426573426573
Epoch: 7, Loss: 0.6404005289077759, Validation Accuracy: 0.6573426573426573
Epoch: 7, Loss: 0.5072473287582397, Validation Accuracy: 0.6608391608391608
Epoch: 7, Loss: 0.5786412358283997, Validation Accuracy: 0.6573426573426573
Epoch: 7, Loss: 0.5931527614593506, Validation Accuracy: 0.6433566433566433
Epoch: 8, Loss: 0.6106604337692261, Validation Accuracy: 0.6433566433566433
Epoch: 8, Loss: 0.6467059850692749, Validation Accuracy: 0.6433566433566433
Epoch: 8, Loss: 0.6986489295959473, Validation Accuracy: 0.6573426573426573
Epoch: 8, Loss: 0.7110292911529541, Validation Accuracy: 0.6538461538461539
Epoch: 8, Loss: 0.6344634294509888, Validation Accuracy: 0.6538461538461539
Epoch: 8, Loss: 0.4974040985107422, Validation Accuracy: 0.6538461538461539
Epoch: 8, Loss: 0.5710303783416748, Validation Accuracy: 0.6573426573426573
Epoch: 8, Loss: 0.5836442708969116, Validation Accuracy: 0.6468531468531469
Epoch: 9, Loss: 0.6042808890342712, Validation Accuracy: 0.6468531468531469
Epoch: 9, Loss: 0.6457432508468628, Validation Accuracy: 0.6468531468531469
Epoch: 9, Loss: 0.6968148946762085, Validation Accuracy: 0.6573426573426573
Epoch: 9, Loss: 0.7122361660003662, Validation Accuracy: 0.6573426573426573
Epoch: 9, Loss: 0.6280542016029358, Validation Accuracy: 0.6573426573426573
Epoch: 9, Loss: 0.48882660269737244, Validation Accuracy: 0.6573426573426573
Epoch: 9, Loss: 0.5643866062164307, Validation Accuracy: 0.6608391608391608
Epoch: 9, Loss: 0.5753633379936218, Validation Accuracy: 0.6538461538461539
Epoch: 10, Loss: 0.5978206396102905, Validation Accuracy: 0.6538461538461539
Epoch: 10, Loss: 0.644151508808136, Validation Accuracy: 0.6468531468531469
Epoch: 10, Loss: 0.695034921169281, Validation Accuracy: 0.6573426573426573
Epoch: 10, Loss: 0.7126811742782593, Validation Accuracy: 0.6573426573426573
Epoch: 10, Loss: 0.6227348446846008, Validation Accuracy: 0.6573426573426573
Epoch: 10, Loss: 0.4803176522254944, Validation Accuracy: 0.6573426573426573
Epoch: 10, Loss: 0.5574368238449097, Validation Accuracy: 0.6573426573426573
Epoch: 10, Loss: 0.567111611366272, Validation Accuracy: 0.6503496503496503
Epoch: 11, Loss: 0.5913469791412354, Validation Accuracy: 0.6503496503496503
Epoch: 11, Loss: 0.6416419744491577, Validation Accuracy: 0.6538461538461539
Epoch: 11, Loss: 0.6932559013366699, Validation Accuracy: 0.6573426573426573
Epoch: 11, Loss: 0.7137739658355713, Validation Accuracy: 0.6573426573426573
Epoch: 11, Loss: 0.617378830909729, Validation Accuracy: 0.6538461538461539
Epoch: 11, Loss: 0.4722314774990082, Validation Accuracy: 0.6503496503496503
Epoch: 11, Loss: 0.5506305694580078, Validation Accuracy: 0.6503496503496503
Epoch: 11, Loss: 0.5594512224197388, Validation Accuracy: 0.6538461538461539
Epoch: 12, Loss: 0.5861461758613586, Validation Accuracy: 0.6503496503496503
Epoch: 12, Loss: 0.6391829252243042, Validation Accuracy: 0.6503496503496503
Epoch: 12, Loss: 0.690549373626709, Validation Accuracy: 0.6538461538461539
Epoch: 12, Loss: 0.7145882248878479, Validation Accuracy: 0.6503496503496503
Epoch: 12, Loss: 0.6126421689987183, Validation Accuracy: 0.6573426573426573
Epoch: 12, Loss: 0.4649009108543396, Validation Accuracy: 0.6573426573426573
Epoch: 12, Loss: 0.5442847013473511, Validation Accuracy: 0.6538461538461539
Epoch: 12, Loss: 0.5534595251083374, Validation Accuracy: 0.6468531468531469
Epoch: 13, Loss: 0.5814194679260254, Validation Accuracy: 0.6468531468531469
Epoch: 13, Loss: 0.6364545226097107, Validation Accuracy: 0.6433566433566433
Epoch: 13, Loss: 0.6902283430099487, Validation Accuracy: 0.6538461538461539
Epoch: 13, Loss: 0.7163193821907043, Validation Accuracy: 0.6503496503496503
Epoch: 13, Loss: 0.6077008247375488, Validation Accuracy: 0.6538461538461539
Epoch: 13, Loss: 0.4577617347240448, Validation Accuracy: 0.6573426573426573
Epoch: 13, Loss: 0.5393982529640198, Validation Accuracy: 0.6573426573426573
Epoch: 13, Loss: 0.5457847118377686, Validation Accuracy: 0.6433566433566433
Epoch: 14, Loss: 0.5763950347900391, Validation Accuracy: 0.6433566433566433
Epoch: 14, Loss: 0.6351915597915649, Validation Accuracy: 0.6398601398601399
Epoch: 14, Loss: 0.6895017623901367, Validation Accuracy: 0.6573426573426573
Epoch: 14, Loss: 0.7182782888412476, Validation Accuracy: 0.6538461538461539
Segmentation fault (core dumped)

Any thoughts on what could be causing this?

Probably unrelated, but could you post the code for get_valid_accuracy? I’m wondering if model.eval() isn’t called or if model.train() isn’t called after the validation step finishes.

Thanks, here it is

def get_valid_accuracy(valid_X, valid_y, model, batch_size=32):
    """
    Computes validation accuracy given a validation set.

    The model is evaluated in eval mode with gradient tracking disabled, and
    its previous train/eval mode is restored before returning, so calling this
    from inside a training loop does not silently turn dropout off.

    Args:
        valid_X: Tensor of validation inputs; rows are samples.
        valid_y: Tensor of 0/1 labels, one per row of ``valid_X``.
        model: Module whose output, once flattened, gives one probability per
            sample.
        batch_size: Number of rows scored per forward pass.

    Returns:
        Fraction of samples whose thresholded prediction (probability >= 0.5)
        equals the label.
    """
    was_training = model.training
    model.eval()
    all_preds = []
    try:
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            # Stepping by start offset never produces an empty trailing batch,
            # even when the set size is an exact multiple of batch_size.
            for start_idx in range(0, valid_X.size(0), batch_size):
                batch_X = valid_X[start_idx:start_idx + batch_size, :]

                probs = model(batch_X).flatten()
                all_preds += (probs >= 0.5).long().tolist()
    finally:
        # Hand the model back in the mode the caller left it in.
        model.train(was_training)

    all_preds = np.array(all_preds)
    valid_y = np.array(valid_y.cpu())
    acc = np.mean(all_preds == valid_y)
    return acc

I just tried adding a call to model.train() after every get_valid_accuracy call in the loop, but I’m still getting the same issue.

OK, it might be more worthwhile to go after the segmentation fault first. Is it possible to get a stack trace with something like

$ gdb --args python my_script.py
...
Reading symbols from python...done.
(gdb) run
...
(gdb) backtrace
...

?
Additionally, what is the output of nvidia-smi?

Here is the backtrace output

Thread 19 "python" received signal SIGSEGV, Segmentation fault.
[Switching to Thread 0x7ffeb98e6700 (LWP 94700)]
0x00007ffebb50ff05 in ?? () from /usr/lib/x86_64-linux-gnu/libcuda.so.1
(gdb) backtrace
#0  0x00007ffebb50ff05 in ?? () from /usr/lib/x86_64-linux-gnu/libcuda.so.1
#1  0x00007ffebb448c07 in ?? () from /usr/lib/x86_64-linux-gnu/libcuda.so.1
#2  0x00007ffebb573f7d in ?? () from /usr/lib/x86_64-linux-gnu/libcuda.so.1
#3  0x00007ffebb420e6d in ?? () from /usr/lib/x86_64-linux-gnu/libcuda.so.1
#4  0x00007ffebb4213ff in ?? () from /usr/lib/x86_64-linux-gnu/libcuda.so.1
#5  0x00007ffebb327f05 in ?? () from /usr/lib/x86_64-linux-gnu/libcuda.so.1
#6  0x00007ffebb328052 in ?? () from /usr/lib/x86_64-linux-gnu/libcuda.so.1
#7  0x00007ffebb4ee87d in cuMemsetD8Async () from /usr/lib/x86_64-linux-gnu/libcuda.so.1
#8  0x00007ffefde3bd2e in ?? () from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libcudart-6d56b25a.so.11.0
#9  0x00007ffefde1989b in ?? () from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libcudart-6d56b25a.so.11.0
#10 0x00007ffefde57311 in cudaMemsetAsync ()
   from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libcudart-6d56b25a.so.11.0
#11 0x00007fff257dd155 in scaleFilter4d(cudnnContext*, cudnnFilter4dStruct*, void*, void const*) ()
   from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cuda_cpp.so
#12 0x00007fff25abbc8b in cudnn::cnn::Wgrad2dAlgo0Engine<float, float, float>::execute_internal_impl(cudnn::backend::VariantPack const&, CUstream_st*) ()
   from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cuda_cpp.so
#13 0x00007fff250386f3 in cudnn::cnn::EngineInterface::execute(cudnn::backend::VariantPack const&, CUstream_st*) ()
   from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cuda_cpp.so
#14 0x00007fff25735750 in cudnn::cnn::EngineContainer<(cudnnBackendEngineName_t)2020, 4096ul>::execute_internal_impl(cudnn::backend::VariantPack const&, CUstream_st*) () from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cuda_cpp.so
---Type <return> to continue, or q <return> to quit---
#15 0x00007fff250386f3 in cudnn::cnn::EngineInterface::execute(cudnn::backend::VariantPack const&, CUstream_st*) ()

   from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cuda_cpp.so


#16 0x00007fff254aeebc in cudnn::cnn::AutoTransformationExecutor::execute_pipeline(cudnn::cnn::ConvolutionEngine&, cudnn::backend::VariantPack const&, CUstream_st*) const () from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cuda_cpp.so

#17 0x00007fff2575c4f1 in cudnn::cnn::GeneralizedConvolutionEngine<cudnn::cnn::EngineContainer<(cudnnBackendEngineName_t)2020, 4096ul> >::execute_internal_impl(cudnn::backend::VariantPack const&, CUstream_st*) ()
   from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cuda_cpp.so
#18 0x00007fff250386f3 in cudnn::cnn::EngineInterface::execute(cudnn::backend::VariantPack const&, CUstream_st*) ()
   from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cuda_cpp.so
#19 0x00007fff254330db in cudnn::backend::execute(cudnnContext*, cudnn::backend::ExecutionPlan&, cudnn::backend::VariantPack&) ()
   from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cuda_cpp.so
#20 0x00007fff2572719d in cudnn::backend::EnginesAlgoMap<cudnnConvolutionBwdFilterAlgo_t, 7>::execute_wrapper(cudnnContext*, cudnnConvolutionBwdFilterAlgo_t, cudnn::backend::ExecutionPlan&, cudnn::backend::VariantPack&) ()
   from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cuda_cpp.so
#21 0x00007fff25726d31 in cudnn::backend::cudnnConvolutionBackwardFilter(cudnnContext*, void const*, cudnnTensorStruct const*, void const*, cudnnTensorStruct const*, void const*, cudnnConvolutionStruct const*, cudnnConvolutionBwdFilterAlgo_t, void*, unsigned long, void const*, cudnnFilterStruct const*, void*)
    () from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cuda_cpp.so
#22 0x00007fff2504091a in cudnnConvolutionBackwardFilter ()
   from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cuda_cpp.so
#23 0x00007fff2403bfbd in at::native::raw_cudnn_convolution_backward_weight_out_32bit(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef---Type <return> to continue, or q <return> to quit---
<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool)::{lambda(cudnnConvolutionBwdFilterAlgoPerf_t const&)#1}::operator()(cudnnConvolutionBwdFilterAlgoPerf_t const&) const () from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cuda_cpp.so
#24 0x00007fff2403f01a in at::native::raw_cudnn_convolution_backward_weight_out_32bit(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool) ()
   from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cuda_cpp.so
#25 0x00007fff2403fcaf in at::native::raw_cudnn_convolution_backward_weight_out(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool) ()
   from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cuda_cpp.so
#26 0x00007fff240390fa in at::native::cudnn_convolution_backward_weight(char const*, c10::ArrayRef<long>, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool) ()
   from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cuda_cpp.so
#27 0x00007fff240397db in at::native::cudnn_convolution_backward_weight(c10::ArrayRef<long>, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool) ()
   from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cuda_cpp.so
#28 0x00007fff8637f6e2 in at::(anonymous namespace)::(anonymous namespace)::wrapper_cudnn_convolution_backward_weight(c10::ArrayRef<long>, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool) ()
   from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cuda_cu.so
#29 0x00007fff8637f771 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (c10::ArrayRef<long>, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool), &at::(anonymous namespace)::(anonymous namespace)::wrapper_cudnn_convolution_backward_weight>, at::Tensor, c10::guts::typelist::typelist<c10::ArrayRef<long>, at::Tenso---Type <return> to continue, or q <return> to quit---
r const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool> >, at::Tensor (c10::ArrayRef<long>, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool)>::call(c10::OperatorKernel*, c10::ArrayRef<long>, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool) ()
   from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cuda_cu.so
#30 0x00007fff749d0485 in at::Tensor c10::Dispatcher::call<at::Tensor, c10::ArrayRef<long>, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool>(c10::TypedOperatorHandle<at::Tensor (c10::ArrayRef<long>, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool)> const&, c10::ArrayRef<long>, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool) const ()
   from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so
#31 0x00007fff74865ea1 in at::cudnn_convolution_backward_weight(c10::ArrayRef<long>, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool) ()
   from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so
#32 0x00007fff240335c9 in at::native::cudnn_convolution_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool, std::array<bool, 2ul>) ()
   from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cuda_cpp.so
#33 0x00007fff8637f4f7 in at::(anonymous namespace)::(anonymous namespace)::wrapper_cudnn_convolution_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool, std::array<bool, 2ul>) ()
   from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cuda_cu.so
#34 0x00007fff8637f58a in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::tuple<at::Tensor, at::Tensor> (at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bo---Type <return> to continue, or q <return> to quit---
ol, bool, std::array<bool, 2ul>), &at::(anonymous namespace)::(anonymous namespace)::wrapper_cudnn_convolution_backward>, std::tuple<at::Tensor, at::Tensor>, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool, std::array<bool, 2ul> > >, std::tuple<at::Tensor, at::Tensor> (at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool, std::array<bool, 2ul>)>::call(c10::OperatorKernel*, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool, std::array<bool, 2ul>) ()
   from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cuda_cu.so
#35 0x00007fff74866537 in at::cudnn_convolution_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool, std::array<bool, 2ul>) ()
   from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so
#36 0x00007fff760e379f in torch::autograd::VariableType::(anonymous namespace)::cudnn_convolution_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool, std::array<bool, 2ul>) ()
   from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so
#37 0x00007fff760e3eba in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::tuple<at::Tensor, at::Tensor> (at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool, std::array<bool, 2ul>), &torch::autograd::VariableType::(anonymous namespace)::cudnn_convolution_backward>, std::tuple<at::Tensor, at::Tensor>, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool, std::array<bool, 2ul> > >, std::tuple<at::Tensor, at::Tensor> (at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool, std::array<bool, 2ul>)>::call(c10::OperatorKernel*, at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool, std::array<bool, 2ul>) ()
   from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so
---Type <return> to continue, or q <return> to quit---
#38 0x00007fff74866537 in at::cudnn_convolution_backward(at::Tensor const&, at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, long, bool, bool, bool, std::array<bool, 2ul>) ()
   from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so
#39 0x00007fff75f4420c in torch::autograd::generated::CudnnConvolutionBackward::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) ()
   from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so
#40 0x00007fff765be771 in torch::autograd::Node::operator()(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) ()
   from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so
#41 0x00007fff765ba57b in torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&) ()
   from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so
#42 0x00007fff765bb19f in torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) ()
   from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so
#43 0x00007fff765b2979 in torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) ()
   from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_cpu.so
#44 0x00007fffe6d3c163 in torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) ()
   from /home/ubuntu/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/lib/libtorch_python.so
#45 0x00007fffe7d486df in ?? () from /usr/lib/x86_64-linux-gnu/libstdc++.so.6
#46 0x00007ffff7bbb6db in start_thread (arg=0x7ffeb98e6700) at pthread_create.c:463
#47 0x00007ffff6f3771f in clone () at ../sysdeps/unix/sysv/linux/x86_64/clone.S:95

Output of nvidia-smi is

Fri Jun 11 19:14:32 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.119.03   Driver Version: 450.119.03   CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla M60           On   | 00000000:00:1E.0 Off |                    0 |
| N/A   37C    P0    38W / 150W |      0MiB /  7618MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

Interesting, could you try and isolate this to a self-contained minimal training/val loop that still causes the segfault (it can have random training data if that works)? I will see if I can reproduce it on similar hardware.