torch.nn.LSTM layer error on GPU

Hi, I’m having a problem that is specific to the GPU. This code works on the CPU, but yields “Child terminated with signal 11” when executed on the GPU.
The class I have is the following:

class CustomLSTM(torch.nn.Module):
    def __init__(self, input_size=100, embedding_dict_size=466553, batch_size = 25, hidden_size=30, output_dim=2, num_layers=1):
        super(CustomLSTM, self).__init__()
        # Attributes
        self.input_size = input_size # Embedding dimension / LSTM input feature size (here also the number of words per comment)
        self.embedding_dict_size = embedding_dict_size # Size of word dictionary
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.output_dim = output_dim
        self.num_layers = num_layers

        # Embedding layer
        self.embedding = torch.nn.Embedding(num_embeddings=self.embedding_dict_size, embedding_dim=self.input_size)

        # Construct lstm
        self.lstm = torch.nn.LSTM(input_size=self.input_size, hidden_size=self.hidden_size, num_layers=self.num_layers)

        # Add linear layer for the end
        self.linear = torch.nn.Linear(self.hidden_size, self.output_dim)

    def forward(self, X):
        X = self.embedding(X)
        trans_X = X.transpose(0, 1) # Make it to [sequence length, batch size, input_size]
        hidden_state = torch.zeros(1, len(X), self.hidden_size)
        cell_state = torch.zeros(1, len(X), self.hidden_size)
        outputs, (_, _) = self.lstm(trans_X, (hidden_state, cell_state)) # From debugging, this is the last line that executes before the crash
        outputs = self.linear(outputs[-1].view(self.batch_size, -1)) # outputs[-1] is [batch_size, hidden_size]; linear maps it to [batch_size, output_dim]
        outputs = torch.nn.Sigmoid()(outputs)

        return outputs # Shape: [Batch, 2 (Length of out vector)]

Has anyone resolved this problem before, or can anyone help?

Thank you

Could you post the stack trace from gdb via:

$ gdb --args python my_script.py
...
Reading symbols from python...done.
(gdb) run
...
(gdb) backtrace
...

Also, could you post a reproducible code snippet (use random tensors as the inputs) as well as the PyTorch, CUDA, and cudnn versions you are using?
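For instance, random tensors roughly along these lines (shapes are just placeholders based on your defaults) would be enough:

import torch

random_X = torch.randint(0, 466553, (25, 100), dtype=torch.long)  # [batch_size, seq_len] word indices for the embedding
random_Y = torch.randint(0, 2, (25,), dtype=torch.long)           # binary labels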

Hi @ptrblck ,

Thank you for responding. I am unable to use the gdb --args python script command as it outputs the error:

"PATH/TO/PYTHON/python": not in executable format: File format not recognized

However, here is the other information you asked for:

  • PyTorch version: 0.4.1
  • CUDA version: 9.1.85
  • cudnn version: 7005
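For reference, a minimal way to check these from Python (assuming the standard torch API):

import torch
print(torch.__version__)               # PyTorch version
print(torch.version.cuda)              # CUDA version the binaries were built with
print(torch.backends.cudnn.version())  # cudnn version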

Here’s the code:

#!/usr/bin/env python

# Construct model
print('LSTMERS Project: Yelp Sentiment Analysis using LSTM')

import torch
import numpy as np

class YelpLSTM(torch.nn.Module):
    def __init__(self, input_size=100, embedding_dict_size=466553, batch_size = 10, hidden_size=30, output_dim=2, num_layers=1):
        super(YelpLSTM, self).__init__()
        # Attributes
        self.input_size = input_size # Embedding dimension / LSTM input feature size (here also the number of words per comment)
        self.embedding_dict_size = embedding_dict_size # Size of word dictionary
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.output_dim = output_dim
        self.num_layers = num_layers
        
        # Embedding layer
        self.embedding = torch.nn.Embedding(num_embeddings=self.embedding_dict_size, embedding_dim=self.input_size)
        
        # Construct lstm
        self.lstm = torch.nn.LSTM(input_size=self.input_size, hidden_size=self.hidden_size, num_layers=self.num_layers)
        
        # Add linear layer for the end
        self.linear = torch.nn.Linear(self.hidden_size, self.output_dim)
    
    def forward(self, X):
        X = self.embedding(X)
        
        trans_X = X.transpose(0, 1) # Make it to [sequence length, batch size, input_size]
        
        hidden_state = torch.zeros(1, len(X), self.hidden_size)
        cell_state = torch.zeros(1, len(X), self.hidden_size)
        
        outputs, (_, _) = self.lstm(trans_X, (hidden_state, cell_state))
        outputs = self.linear(outputs[-1].view(self.batch_size, -1)) # outputs[-1] is [batch_size, hidden_size]; linear maps it to [batch_size, output_dim]
        outputs = torch.nn.Sigmoid()(outputs)
        
        return outputs # Shape: [Batch, 2 (Length of out vector)]
# Define, train, eval model

# Define model and optimizer
batch_size = 25
device = ('cuda' if torch.cuda.is_available() else 'cpu')

model = YelpLSTM(batch_size = 25)
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

criterion = torch.nn.CrossEntropyLoss()


print('Loading dataset...')
# Make custom datasets
from torch.utils.data import DataLoader, Dataset
class CustomDataset(Dataset):
    def __init__(self, X, Y):
        self.X = X
        self.Y = Y
        
    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]
    
# Paths for saved dataloaders
# train_loader_path = '../dataset/tensor_datasets/train_loader.torch'
# test_loader_path = '../dataset/tensor_datasets/test_loader.torch'

# Load train and test object
# train_loader = torch.load(train_loader_path)
# test_loader = torch.load(test_loader_path)

# For reproducible error
random_X, random_Y = torch.zeros(100,100).long(), torch.zeros(100).long()
random_dataset = CustomDataset(random_X, random_Y)
train_loader = DataLoader(random_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(random_dataset, batch_size=batch_size, shuffle=False)

print('Training and evaluating model...')
# Train and evaluate model
epochs = 10
for epoch in range(epochs):
    print('Epoch: '+ str(epoch))
    
    # Training
    model.train()
    train_batches = len(train_loader)
    epoch_acc = 0.0
    epoch_loss = 0.0

    for i, data in enumerate(train_loader):
        optimizer.zero_grad()
        
        X_batch, Y_batch = data
        X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)
        
        # Forward
        y_pred = model(X_batch)
        print(Y_batch)
        loss = criterion(y_pred, Y_batch.long())
        
        # Get stats
        right_count = torch.sum(Y_batch.cpu() == torch.argmax(y_pred, 1).long()).cpu().item()
        batch_acc = right_count / batch_size
        epoch_acc += batch_acc
        epoch_loss += loss

        # Backwards
        loss.backward()
        optimizer.step()
    print(len(train_loader) / batch_size)
    epoch_acc /= train_batches
    epoch_loss /= train_batches
    print('Epoch: ' + str(epoch) + ', training loss: ' + str(epoch_loss.item()) + ', training accuracy: ' + str(epoch_acc))
    
    # Evaluate model
    model.eval()
    test_batches = len(test_loader)
    epoch_acc = 0.0
    epoch_loss = 0.0

    for i, data in enumerate(test_loader):
        # Define batches
        X_batch, Y_batch = data
        X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)

        # Forward
        y_pred = model(X_batch)
        loss = criterion(y_pred, Y_batch.long())
        
        # Get stats
        right_count = torch.sum(Y_batch.cpu() == torch.argmax(y_pred, 1).long()).cpu().item()
        batch_acc = right_count / batch_size
        epoch_acc += batch_acc
        epoch_loss += loss

    
    epoch_acc /= test_batches
    epoch_loss /= test_batches
    print('Epoch: ' + str(epoch) + ', evaluation loss: ' + str(epoch_loss.item()) + ', evaluation accuracy: ' + str(epoch_acc))

Thank you

Oh, PyTorch 0.4.1 was released in July 2018, so could you update to the latest stable release (1.4.0) or the nightly binaries and rerun the code?

I see, I will try to install the newer release then. Will update!

Hi @ptrblck, I tested it on the newer version of PyTorch (1.4.0), but I’m still unable to run it. I ran it on Google Colab this time, so I can’t see exactly what the error is, but it keeps crashing on the GPU even though it runs fine on the CPU.

Thanks for the code! :slight_smile:

I got a device mismatch error, since hidden_state and cell_state are both initialized on the CPU even when you push the model to the GPU.
Could you try the following instead:

    def forward(self, X):
        X = self.embedding(X)
        
        trans_X = X.transpose(0, 1) # Make it to [sequence length, batch size, input_size]
        
        hidden_state = torch.zeros(1, len(X), self.hidden_size).to(X.device)
        cell_state = torch.zeros(1, len(X), self.hidden_size).to(X.device)
...
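As a side note, you could also create the states directly on the right device, which avoids allocating them on the CPU first (just a small variant of the same fix):

        hidden_state = torch.zeros(1, len(X), self.hidden_size, device=X.device)
        cell_state = torch.zeros(1, len(X), self.hidden_size, device=X.device)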

Also, the right_count calculation will raise another device mismatch error.
You would have to call .cpu() on the torch.argmax operation, whereas in your code it’s called on the sum:

right_count = torch.sum(Y_batch.cpu() == torch.argmax(y_pred, 1).long().cpu()).item()
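Alternatively, you could keep the whole comparison on the GPU and only move the final scalar back with .item(), e.g.:

right_count = (Y_batch == torch.argmax(y_pred, 1)).sum().item()  # comparison stays on the tensors' device; .item() copies only the scalar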

After fixing these issues, the code runs fine.

Let me know if that helps.


Thank you! It worked!