CrossEntropyLoss on LSTM Autoencoder converges on CPU but not GPU

I have built a simple LSTM autoencoder for sequence data. When I run it on the CPU, the cross-entropy loss converges to a nice low value of 0.075. If I run the exact same code on a machine with CUDA using the GPU, the loss bounces around 1.36 and never converges to a similarly low value.
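
One quick sanity check on these two numbers: a model that assigns equal probability to every token has a cross-entropy of ln(K) for K classes. Assuming the vocabulary has 4 tokens (a guess based on `nn.Embedding(4, n_features)` in the encoder below; the actual number of output classes is `n_features`, which is not shown), that chance-level baseline is about 1.386. Under that assumption, a loss near 1.36 would suggest the GPU run is barely learning at all rather than settling into a slightly worse optimum.

```python
import math

def uniform_cross_entropy(num_classes: int) -> float:
    """Cross-entropy (in nats) of a uniform prediction over num_classes."""
    return -math.log(1.0 / num_classes)

# Assumption: 4 tokens in the vocabulary, as nn.Embedding(4, ...) suggests.
baseline = uniform_cross_entropy(4)
print(f"chance-level loss: {baseline:.3f}")  # chance-level loss: 1.386
```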

This is my Dataset:

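import copy

import numpy as np
import pandas as pd
import torch as T
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

class SequenceDataset(Dataset):
    def __init__(self, data_path, device):
        self.data_file = pd.read_csv(data_path)
        self.sequences = self.data_file["text"]
        self.labels = T.tensor(self.data_file["label"], dtype=T.float).to(device)
        self.n_examples = len(self.labels)
        # One row of token ids per sequence; assumes all sequences have the same length.
        self.tokens = T.zeros([self.n_examples, len(self.sequences[0])], dtype=T.long).to(device)

        tokenizer = get_tokenizer(None)  # None selects plain str.split()

        def yield_tokens(data_iterator):
            for t in data_iterator:
                yield tokenizer(t)

        # Iterating over the joined string yields single characters, so the vocab is character-level.
        self.vocab = build_vocab_from_iterator(yield_tokens(" ".join(self.sequences)))

        for i, d in enumerate(self.sequences):
            # " ".join(d) separates the characters so that split() tokenizes per character.
            self.tokens[i] = T.tensor(self.vocab(tokenizer(" ".join(d))))

    def __len__(self):
        return self.n_examples

    def __getitem__(self, item):
        return {'text': self.tokens[item], 'label': self.labels[item]}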

This is my Autoencoder:

class Encoder(nn.Module):
    def __init__(self, n_features, hidden_size):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(4, n_features)
        self.lstm_enc = nn.LSTM(
            input_size=n_features,
            hidden_size=hidden_size,
            batch_first=True,
        )

    def forward(self, x):
        embedded = self.embedding(x)         # (batch, seq_len, n_features)
        _, (h, _) = self.lstm_enc(embedded)  # h: (1, batch, hidden_size)
        enc = h.squeeze(dim=0)               # (batch, hidden_size)
        return enc

class Decoder(nn.Module):
    def __init__(self, n_features, hidden_size):
        super(Decoder, self).__init__()
        self.lstm_dec = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            batch_first=True,
        )

        self.linearlayer = nn.Linear(hidden_size, n_features)

    def forward(self, x):
        # Repeat the encoding once per output position (sequence length 69).
        enc = x.unsqueeze(1).repeat(1, 69, 1)
        out, (_, _) = self.lstm_dec(enc)
        out = self.linearlayer(out)  # (batch, 69, n_features) logits
        return out

class LSTM_AutoEncoder(nn.Module):
    def __init__(self, n_features, hidden_size):
        super(LSTM_AutoEncoder, self).__init__()
        self.encoder = Encoder(n_features, hidden_size)
        self.decoder = Decoder(n_features, hidden_size)

    def forward(self, x):
        h = self.encoder(x)
        out = self.decoder(h)

        return out

This is my model call:

model = LSTM_AutoEncoder(n_features, hidden_size)
model = model.to(device)
best_model_wts = copy.deepcopy(model.state_dict())
optimizer = T.optim.Adam(model.parameters(),lr=1e-3)
criterion = nn.CrossEntropyLoss(reduction='mean').to(device)

And this is my training loop:

    history = dict(train=[])
    best_loss = 2
    SD = SequenceDataset(file, device) # Load Data
    DL = DataLoader(SD, batch_size=batch_size, shuffle=True) # Batches

    for epoch in range(1,n_epochs+1):

        model = model.train()
        train_losses = []

        for batch in DL:

            actual = batch['text'].to(device)

            pred = model(actual)

            l = [criterion(pred[:,i], actual[:,i]) for i in range(69)]

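            loss = sum(l)/len(l)

            # Backpropagate and update the weights.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Record the batch loss for the epoch average below.
            train_losses.append(loss.item())
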
        train_loss = np.mean(train_losses)

        model = model.eval()

        if train_loss < best_loss:
            best_loss = train_loss
            best_model_wts = copy.deepcopy(model.state_dict())
        if epoch % 100 == 0:
            print(f'Epoch {epoch}: CE loss {train_loss}')
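
As a side note on the loss computation, the per-position loop can usually be replaced by a single call. The sketch below uses random tensors and made-up sizes (batch of 8, sequence length 69, 4 classes) rather than the real data. Because every position contributes the same number of examples, averaging the 69 per-position means gives the same value as the mean over all flattened positions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, n_classes = 8, 69, 4  # illustrative sizes, not taken from the post

pred = torch.randn(batch_size, seq_len, n_classes)           # stand-in for the model output (logits)
actual = torch.randint(0, n_classes, (batch_size, seq_len))  # stand-in for the token ids

criterion = nn.CrossEntropyLoss(reduction="mean")

# Loop form, as in the training loop above.
per_step = [criterion(pred[:, i], actual[:, i]) for i in range(seq_len)]
loop_loss = sum(per_step) / len(per_step)

# Flattened form: (batch * seq_len, n_classes) logits against (batch * seq_len,) targets.
flat_loss = criterion(pred.reshape(-1, n_classes), actual.reshape(-1))

print(loop_loss.item(), flat_loss.item())
```

The flattened form launches one kernel instead of 69, which tends to matter more on the GPU than on the CPU. It changes speed, not the value being optimized.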

I have CUDA 11.5 and PyTorch 1.10.

Are you able to reproduce this behavior using random data? If not, which dataset are you using?
Could you also post the output of python -m torch.utils.collect_env here, please?
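
A reproduction on random data could look like the sketch below. The sizes (256 sequences of length 69 over 4 tokens) and the `make_random_tokens` helper are assumptions for illustration only. The idea is to fix the seed, build the data on the CPU, and then train the same model once with the CPU and once with the GPU as the target device.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def make_random_tokens(n_examples=256, seq_len=69, vocab_size=4, seed=0):
    """Hypothetical helper: random token ids standing in for the real sequences."""
    gen = torch.Generator().manual_seed(seed)
    return torch.randint(0, vocab_size, (n_examples, seq_len), generator=gen)

tokens = make_random_tokens()
loader = DataLoader(TensorDataset(tokens), batch_size=32, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for (batch,) in loader:
    batch = batch.to(device)  # move each batch right before the forward pass
    # pred = model(batch); loss = ...  (plug in the model and loss from the post)
    break

print(tokens.shape, batch.device)
```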

I’m using a custom data set that is essentially two columns - a sequence and a label. I haven’t tried it on any other data set.

Here is the output from python -m torch.utils.collect_env:

Collecting environment information...
PyTorch version: 1.10.0
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.3 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.31

Python version: 3.9.7 (default, Sep 16 2021, 13:09:58)  [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.13.0-28-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.5.50
GPU models and configuration: GPU 0: Quadro RTX 4000
Nvidia driver version: 495.29.05
cuDNN version: Probably one of the following:
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.21.2
[pip3] torch==1.10.0
[pip3] torch-utils==0.1.2
[pip3] torchaudio==0.10.0
[pip3] torchtext==0.11.0
[pip3] torchvision==0.11.1
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               11.3.1               h2bc3f7f_2  
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] mkl                       2021.4.0           h06a4308_640  
[conda] mkl-service               2.4.0            py39h7f8727e_0  
[conda] mkl_fft                   1.3.1            py39hd3c417c_0  
[conda] mkl_random                1.2.2            py39h51133e4_0  
[conda] numpy                     1.21.2           py39h20f2e39_0  
[conda] numpy-base                1.21.2           py39h79a1101_0  
[conda] pytorch                   1.10.0          py3.9_cuda11.3_cudnn8.2.0_0    pytorch
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torch-utils               0.1.2                    pypi_0    pypi
[conda] torchaudio                0.10.0               py39_cu113    pytorch
[conda] torchtext                 0.11.0                     py39    pytorch
[conda] torchvision               0.11.1               py39_cu113    pytorch

Thanks so much for looking at this. I appreciate it. I was thinking it’s how I’m using .to(device), but I’m not getting any errors.
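
A device mismatch between the model and its inputs normally raises a RuntimeError rather than failing silently, but it is cheap to check explicitly. A small sketch, with a hypothetical `devices_of` helper and a stand-in `nn.LSTM` in place of the real autoencoder:

```python
import torch
import torch.nn as nn

def devices_of(module: nn.Module) -> set:
    """Hypothetical helper: device types holding the module's parameters and buffers."""
    tensors = list(module.parameters()) + list(module.buffers())
    return {t.device.type for t in tensors}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.LSTM(input_size=4, hidden_size=8, batch_first=True).to(device)  # stand-in model
batch = torch.randn(2, 69, 4).to(device)

print(devices_of(model), batch.device.type)
assert devices_of(model) == {batch.device.type}
```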