I’m seeing very strange behavior when training a basic CNN text-classification model with CUDA 11.1 on an AWS machine; I have tried both p2 and g3 instance types. When training on the CPU, the model does well and reaches about 85% validation accuracy after ~15 epochs. With the exact same code on CUDA, however, I get random segmentation faults during training, and the validation accuracy barely changes between batches (hovering around 65%) even though the loss does change. Here is my training script:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from sklearn.utils import shuffle
from cnn import BasicConvModel
from data import (
load_data,
preprocess_text,
create_vocabulary,
get_valid_accuracy
)
# Seed PyTorch's RNGs (weight initialisation, dropout masks) for repeatable runs.
torch.manual_seed(12)

# Train on the GPU when one is available. To force a CPU run, set
# device = torch.device("cpu") here; everything below follows `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Model / optimisation hyper-parameters.
output_size = 1
embed_dim = 600
num_filters = 200
kernel_sizes = [2, 3, 4, 5]
batch_size = 16

# Load the raw data, tokenise it and build the token -> index vocabulary.
X, y = load_data('../data/classification_sample.csv')
processed_text = preprocess_text(X)
vocab = create_vocabulary(processed_text)

model = BasicConvModel(vocab, embed_dim, num_filters, kernel_sizes, output_size)
# Place the model on the same device the data tensors are created on. Keying
# off `device` (instead of torch.cuda.is_available()) keeps model and data
# together even when `device` is overridden to the CPU.
model.to(device)

# BasicConvModel ends in a sigmoid, so its outputs are probabilities and
# plain BCELoss applies.
loss_fn = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Shuffle texts and labels together (fixed seed) before splitting.
processed_text, y = shuffle(processed_text, y, random_state=132)

# Map every token to its vocabulary index and right-pad each sequence with the
# <PAD> index so all rows share the length of the longest text.
max_idx = max(len(text) for text in processed_text)
pad_idx = vocab['<PAD>']
text_idxs = [
    [vocab[tok] for tok in text] + [pad_idx] * (max_idx - len(text))
    for text in processed_text
]

full_tensor = torch.tensor(text_idxs, device=device)
y = torch.tensor(y, device=device)

num_epochs = 20

# The first 20% of the shuffled rows are held out for validation; the rest
# are used for training.
valid_idx = int(0.2 * full_tensor.size(0))
valid_X, X = full_tensor[:valid_idx, :], full_tensor[valid_idx:, :]
valid_y, y = y[:valid_idx], y[valid_idx:]

num_batches = X.size(0) // batch_size
print_every = 10
# Train for num_epochs passes over the training split in mini-batches of
# batch_size rows, reporting validation accuracy every print_every batches.
num_train = X.size(0)
for i in range(num_epochs):
    # Step over batch start offsets: the final partial batch is still used,
    # but an empty batch (and its NaN mean loss) is never produced, which
    # `range(num_batches + 1)` did whenever num_train % batch_size == 0.
    for j, start_idx in enumerate(range(0, num_train, batch_size)):
        # get_valid_accuracy is defined elsewhere and may switch the model to
        # eval mode; re-enable training mode (dropout) before every update.
        model.train()
        end_idx = start_idx + batch_size
        batch_X, batch_y = X[start_idx:end_idx, :], y[start_idx:end_idx]

        optimizer.zero_grad()
        preds = model(batch_X)
        # Drop only the output_size axis: (batch, 1) -> (batch,). A bare
        # .squeeze() also removes the batch axis of a 1-row batch, and
        # BCELoss rejects the resulting () vs (1,) shape mismatch.
        loss = loss_fn(preds.squeeze(1), batch_y.float())
        loss.backward()
        optimizer.step()

        if j % print_every == 0:
            acc = get_valid_accuracy(valid_X, valid_y, model)
            # .item() prints a plain float instead of the tensor repr.
            print(f"Epoch: {i}, Loss: {loss.item()}, Validation Accuracy: {acc}")
And here is the CNN model (`cnn.py`):
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class BasicConvModel(nn.Module):
    """Convolutional text classifier with max-over-sequence pooling.

    Token indices are embedded and passed through one 2-D convolution per
    kernel size (each kernel spans the full embedding width). Each result is
    ReLU-activated and max-pooled over the sequence axis. The pooled features
    are concatenated, optionally dropped out, and mapped through a linear
    layer followed by a sigmoid.

    Args:
        vocab: token -> index mapping; only ``len(vocab)`` is used here, to
            size the embedding table.
        embed_dim: embedding dimension.
        num_filters: output channels of each convolution.
        kernel_sizes: convolution heights (number of tokens spanned). A kernel
            of size ``k`` pads the sequence axis by ``k - 2`` on each side, so
            sizes below 2 would need negative padding.
        output_size: number of sigmoid outputs.
        use_dropout: whether dropout is applied to the pooled features.
        dropout_prob: dropout probability.
    """

    def __init__(self, vocab, embed_dim, num_filters, kernel_sizes,
                 output_size, use_dropout=True, dropout_prob=0.2):
        super().__init__()
        # Keep the constructor arguments for later inspection.
        self.vocab = vocab
        self.embed_dim = embed_dim
        self.num_filters = num_filters
        self.kernel_sizes = kernel_sizes
        self.output_size = output_size
        self.use_dropout = use_dropout
        self.dropout_prob = dropout_prob

        # Embedding table with one row per vocabulary entry.
        self.embed = nn.Embedding(len(vocab), embed_dim)
        # One convolution per kernel size; the kernel covers the whole
        # embedding width, so each output has width 1.
        self.convs_1d = nn.ModuleList([
            nn.Conv2d(1, num_filters, (k, embed_dim), padding=(k - 2, 0))
            for k in kernel_sizes
        ])
        # Always constructed so the module layout does not depend on
        # use_dropout; only applied in forward() when use_dropout is True.
        self.dropout = nn.Dropout(p=dropout_prob)
        self.dense = nn.Linear(len(kernel_sizes) * num_filters, output_size)
        self.sigmoid = nn.Sigmoid()

    def conv_and_pool(self, x, conv):
        """Apply ``conv`` and ReLU, then max-pool over the whole sequence axis.

        Expects the 4-D embedding tensor built in forward() (presumably
        (batch, 1, seq_len, embed_dim)) and returns (batch, num_filters).
        """
        # Conv output has a trailing width-1 axis; drop it.
        x = F.relu(conv(x)).squeeze(3)
        # Pool with a window equal to the full remaining length, then drop
        # the resulting length-1 axis.
        x_max = F.max_pool1d(x, x.size(2)).squeeze(2)
        return x_max

    def forward(self, x):
        """Return sigmoid probabilities with last dimension ``output_size``.

        ``x`` is an integer tensor of token indices, presumably shaped
        (batch, seq_len); the output is then (batch, output_size).
        """
        # Add a channel axis for the 2-D convolutions.
        embeds = self.embed(x).unsqueeze(1)
        pooled = [self.conv_and_pool(embeds, conv) for conv in self.convs_1d]
        x = torch.cat(pooled, 1)
        # Honour the constructor flag; previously dropout ran unconditionally.
        if self.use_dropout:
            x = self.dropout(x)
        logit = self.dense(x)
        return self.sigmoid(logit)
Here is the output from a CPU run:
Epoch: 0, Loss: 0.5930776596069336, Validation Accuracy: 0.6538461538461539
Epoch: 0, Loss: 0.5337628722190857, Validation Accuracy: 0.7062937062937062
Epoch: 0, Loss: 0.8541656732559204, Validation Accuracy: 0.6783216783216783
Epoch: 0, Loss: 1.5076597929000854, Validation Accuracy: 0.7412587412587412
Epoch: 0, Loss: 1.055909514427185, Validation Accuracy: 0.7727272727272727
Epoch: 0, Loss: 0.4178897440433502, Validation Accuracy: 0.7797202797202797
Epoch: 0, Loss: 0.4650037884712219, Validation Accuracy: 0.7867132867132867
Epoch: 0, Loss: 0.3287246823310852, Validation Accuracy: 0.7867132867132867
Epoch: 1, Loss: 0.09397178143262863, Validation Accuracy: 0.7972027972027972
Epoch: 1, Loss: 0.355461061000824, Validation Accuracy: 0.7657342657342657
Epoch: 1, Loss: 0.05974849686026573, Validation Accuracy: 0.7517482517482518
Epoch: 1, Loss: 0.34798571467399597, Validation Accuracy: 0.7867132867132867
Epoch: 1, Loss: 0.21483170986175537, Validation Accuracy: 0.7797202797202797
Epoch: 1, Loss: 0.0788840726017952, Validation Accuracy: 0.7307692307692307
Epoch: 1, Loss: 0.07168418169021606, Validation Accuracy: 0.8006993006993007
Epoch: 1, Loss: 0.1160748153924942, Validation Accuracy: 0.8041958041958042
Epoch: 2, Loss: 0.05112272500991821, Validation Accuracy: 0.7937062937062938
Epoch: 2, Loss: 0.20592114329338074, Validation Accuracy: 0.7937062937062938
Epoch: 2, Loss: 0.05088387429714203, Validation Accuracy: 0.7937062937062938
Epoch: 2, Loss: 0.0787544697523117, Validation Accuracy: 0.8111888111888111
Epoch: 2, Loss: 0.5769421458244324, Validation Accuracy: 0.7587412587412588
Epoch: 2, Loss: 0.0004902268410660326, Validation Accuracy: 0.7972027972027972
Epoch: 2, Loss: 0.4308287501335144, Validation Accuracy: 0.7867132867132867
Epoch: 2, Loss: 0.007755571510642767, Validation Accuracy: 0.7902097902097902
Epoch: 3, Loss: 0.0071701593697071075, Validation Accuracy: 0.7622377622377622
Epoch: 3, Loss: 0.013051643036305904, Validation Accuracy: 0.7727272727272727
Epoch: 3, Loss: 0.007425243500620127, Validation Accuracy: 0.7412587412587412
And here is the output from an example CUDA run:
Epoch: 0, Loss: 0.6711463928222656, Validation Accuracy: 0.6538461538461539
Epoch: 0, Loss: 0.7122050523757935, Validation Accuracy: 0.458041958041958
Epoch: 0, Loss: 0.724312961101532, Validation Accuracy: 0.6188811188811189
Epoch: 0, Loss: 0.7947384715080261, Validation Accuracy: 0.6258741258741258
Epoch: 0, Loss: 0.688117265701294, Validation Accuracy: 0.6188811188811189
Epoch: 0, Loss: 0.5511382222175598, Validation Accuracy: 0.6503496503496503
Epoch: 0, Loss: 0.6020643711090088, Validation Accuracy: 0.6538461538461539
Epoch: 0, Loss: 0.6449810266494751, Validation Accuracy: 0.6643356643356644
Epoch: 1, Loss: 0.645209550857544, Validation Accuracy: 0.6713286713286714
Epoch: 1, Loss: 0.6702262163162231, Validation Accuracy: 0.6643356643356644
Epoch: 1, Loss: 0.7097079753875732, Validation Accuracy: 0.6538461538461539
Epoch: 1, Loss: 0.7172590494155884, Validation Accuracy: 0.6538461538461539
Epoch: 1, Loss: 0.6835118532180786, Validation Accuracy: 0.6363636363636364
Epoch: 1, Loss: 0.5769379734992981, Validation Accuracy: 0.6538461538461539
Epoch: 1, Loss: 0.6344826817512512, Validation Accuracy: 0.6468531468531469
Epoch: 1, Loss: 0.6470801830291748, Validation Accuracy: 0.6433566433566433
Epoch: 2, Loss: 0.6567261815071106, Validation Accuracy: 0.6363636363636364
Epoch: 2, Loss: 0.6605111360549927, Validation Accuracy: 0.6433566433566433
Epoch: 2, Loss: 0.716945469379425, Validation Accuracy: 0.6433566433566433
Epoch: 2, Loss: 0.7049784660339355, Validation Accuracy: 0.6503496503496503
Epoch: 2, Loss: 0.6796182990074158, Validation Accuracy: 0.6538461538461539
Epoch: 2, Loss: 0.571618914604187, Validation Accuracy: 0.6538461538461539
Epoch: 2, Loss: 0.6348196268081665, Validation Accuracy: 0.6538461538461539
Epoch: 2, Loss: 0.6413698196411133, Validation Accuracy: 0.6468531468531469
Epoch: 3, Loss: 0.6526831388473511, Validation Accuracy: 0.6468531468531469
Epoch: 3, Loss: 0.6591930389404297, Validation Accuracy: 0.6538461538461539
Epoch: 3, Loss: 0.7166335582733154, Validation Accuracy: 0.6643356643356644
Epoch: 3, Loss: 0.7043871879577637, Validation Accuracy: 0.6643356643356644
Epoch: 3, Loss: 0.6688210964202881, Validation Accuracy: 0.6573426573426573
Epoch: 3, Loss: 0.5581091642379761, Validation Accuracy: 0.6643356643356644
Epoch: 3, Loss: 0.6201637387275696, Validation Accuracy: 0.6433566433566433
Epoch: 3, Loss: 0.6303338408470154, Validation Accuracy: 0.6433566433566433
Epoch: 4, Loss: 0.6430555582046509, Validation Accuracy: 0.6433566433566433
Epoch: 4, Loss: 0.6545212864875793, Validation Accuracy: 0.6468531468531469
Epoch: 4, Loss: 0.710310161113739, Validation Accuracy: 0.6503496503496503
Epoch: 4, Loss: 0.7034661173820496, Validation Accuracy: 0.6538461538461539
Epoch: 4, Loss: 0.6596505641937256, Validation Accuracy: 0.6573426573426573
Epoch: 4, Loss: 0.5458865165710449, Validation Accuracy: 0.6538461538461539
Epoch: 4, Loss: 0.6087003350257874, Validation Accuracy: 0.6468531468531469
Epoch: 4, Loss: 0.617493748664856, Validation Accuracy: 0.6363636363636364
Epoch: 5, Loss: 0.6346374750137329, Validation Accuracy: 0.6433566433566433
Epoch: 5, Loss: 0.6527444124221802, Validation Accuracy: 0.6433566433566433
Epoch: 5, Loss: 0.7063688039779663, Validation Accuracy: 0.6538461538461539
Epoch: 5, Loss: 0.7048066854476929, Validation Accuracy: 0.6538461538461539
Epoch: 5, Loss: 0.6533452272415161, Validation Accuracy: 0.6538461538461539
Epoch: 5, Loss: 0.5325800180435181, Validation Accuracy: 0.6573426573426573
Epoch: 5, Loss: 0.5976376533508301, Validation Accuracy: 0.6503496503496503
Epoch: 5, Loss: 0.6109954118728638, Validation Accuracy: 0.6503496503496503
Epoch: 6, Loss: 0.626984715461731, Validation Accuracy: 0.6433566433566433
Epoch: 6, Loss: 0.6500746011734009, Validation Accuracy: 0.6468531468531469
Epoch: 6, Loss: 0.7023633718490601, Validation Accuracy: 0.6503496503496503
Epoch: 6, Loss: 0.7072952389717102, Validation Accuracy: 0.6573426573426573
Epoch: 6, Loss: 0.6468454003334045, Validation Accuracy: 0.6573426573426573
Epoch: 6, Loss: 0.5176084637641907, Validation Accuracy: 0.6573426573426573
Epoch: 6, Loss: 0.5881103277206421, Validation Accuracy: 0.6573426573426573
Epoch: 6, Loss: 0.602046549320221, Validation Accuracy: 0.6468531468531469
Epoch: 7, Loss: 0.6179436445236206, Validation Accuracy: 0.6433566433566433
Epoch: 7, Loss: 0.6477866768836975, Validation Accuracy: 0.6468531468531469
Epoch: 7, Loss: 0.7008408904075623, Validation Accuracy: 0.6608391608391608
Epoch: 7, Loss: 0.7075467109680176, Validation Accuracy: 0.6573426573426573
Epoch: 7, Loss: 0.6404005289077759, Validation Accuracy: 0.6573426573426573
Epoch: 7, Loss: 0.5072473287582397, Validation Accuracy: 0.6608391608391608
Epoch: 7, Loss: 0.5786412358283997, Validation Accuracy: 0.6573426573426573
Epoch: 7, Loss: 0.5931527614593506, Validation Accuracy: 0.6433566433566433
Epoch: 8, Loss: 0.6106604337692261, Validation Accuracy: 0.6433566433566433
Epoch: 8, Loss: 0.6467059850692749, Validation Accuracy: 0.6433566433566433
Epoch: 8, Loss: 0.6986489295959473, Validation Accuracy: 0.6573426573426573
Epoch: 8, Loss: 0.7110292911529541, Validation Accuracy: 0.6538461538461539
Epoch: 8, Loss: 0.6344634294509888, Validation Accuracy: 0.6538461538461539
Epoch: 8, Loss: 0.4974040985107422, Validation Accuracy: 0.6538461538461539
Epoch: 8, Loss: 0.5710303783416748, Validation Accuracy: 0.6573426573426573
Epoch: 8, Loss: 0.5836442708969116, Validation Accuracy: 0.6468531468531469
Epoch: 9, Loss: 0.6042808890342712, Validation Accuracy: 0.6468531468531469
Epoch: 9, Loss: 0.6457432508468628, Validation Accuracy: 0.6468531468531469
Epoch: 9, Loss: 0.6968148946762085, Validation Accuracy: 0.6573426573426573
Epoch: 9, Loss: 0.7122361660003662, Validation Accuracy: 0.6573426573426573
Epoch: 9, Loss: 0.6280542016029358, Validation Accuracy: 0.6573426573426573
Epoch: 9, Loss: 0.48882660269737244, Validation Accuracy: 0.6573426573426573
Epoch: 9, Loss: 0.5643866062164307, Validation Accuracy: 0.6608391608391608
Epoch: 9, Loss: 0.5753633379936218, Validation Accuracy: 0.6538461538461539
Epoch: 10, Loss: 0.5978206396102905, Validation Accuracy: 0.6538461538461539
Epoch: 10, Loss: 0.644151508808136, Validation Accuracy: 0.6468531468531469
Epoch: 10, Loss: 0.695034921169281, Validation Accuracy: 0.6573426573426573
Epoch: 10, Loss: 0.7126811742782593, Validation Accuracy: 0.6573426573426573
Epoch: 10, Loss: 0.6227348446846008, Validation Accuracy: 0.6573426573426573
Epoch: 10, Loss: 0.4803176522254944, Validation Accuracy: 0.6573426573426573
Epoch: 10, Loss: 0.5574368238449097, Validation Accuracy: 0.6573426573426573
Epoch: 10, Loss: 0.567111611366272, Validation Accuracy: 0.6503496503496503
Epoch: 11, Loss: 0.5913469791412354, Validation Accuracy: 0.6503496503496503
Epoch: 11, Loss: 0.6416419744491577, Validation Accuracy: 0.6538461538461539
Epoch: 11, Loss: 0.6932559013366699, Validation Accuracy: 0.6573426573426573
Epoch: 11, Loss: 0.7137739658355713, Validation Accuracy: 0.6573426573426573
Epoch: 11, Loss: 0.617378830909729, Validation Accuracy: 0.6538461538461539
Epoch: 11, Loss: 0.4722314774990082, Validation Accuracy: 0.6503496503496503
Epoch: 11, Loss: 0.5506305694580078, Validation Accuracy: 0.6503496503496503
Epoch: 11, Loss: 0.5594512224197388, Validation Accuracy: 0.6538461538461539
Epoch: 12, Loss: 0.5861461758613586, Validation Accuracy: 0.6503496503496503
Epoch: 12, Loss: 0.6391829252243042, Validation Accuracy: 0.6503496503496503
Epoch: 12, Loss: 0.690549373626709, Validation Accuracy: 0.6538461538461539
Epoch: 12, Loss: 0.7145882248878479, Validation Accuracy: 0.6503496503496503
Epoch: 12, Loss: 0.6126421689987183, Validation Accuracy: 0.6573426573426573
Epoch: 12, Loss: 0.4649009108543396, Validation Accuracy: 0.6573426573426573
Epoch: 12, Loss: 0.5442847013473511, Validation Accuracy: 0.6538461538461539
Epoch: 12, Loss: 0.5534595251083374, Validation Accuracy: 0.6468531468531469
Epoch: 13, Loss: 0.5814194679260254, Validation Accuracy: 0.6468531468531469
Epoch: 13, Loss: 0.6364545226097107, Validation Accuracy: 0.6433566433566433
Epoch: 13, Loss: 0.6902283430099487, Validation Accuracy: 0.6538461538461539
Epoch: 13, Loss: 0.7163193821907043, Validation Accuracy: 0.6503496503496503
Epoch: 13, Loss: 0.6077008247375488, Validation Accuracy: 0.6538461538461539
Epoch: 13, Loss: 0.4577617347240448, Validation Accuracy: 0.6573426573426573
Epoch: 13, Loss: 0.5393982529640198, Validation Accuracy: 0.6573426573426573
Epoch: 13, Loss: 0.5457847118377686, Validation Accuracy: 0.6433566433566433
Epoch: 14, Loss: 0.5763950347900391, Validation Accuracy: 0.6433566433566433
Epoch: 14, Loss: 0.6351915597915649, Validation Accuracy: 0.6398601398601399
Epoch: 14, Loss: 0.6895017623901367, Validation Accuracy: 0.6573426573426573
Epoch: 14, Loss: 0.7182782888412476, Validation Accuracy: 0.6538461538461539
Segmentation fault (core dumped)
Any thoughts on what could be causing this?