CUDA error: Misaligned Address Depending on Output Size

Hello!

I’ve run into a weird bug using PyTorch on Google Colab’s GPUs while building a simple RNN-based Seq2Seq model. More specifically, I get CUDA error: misaligned address when I make my backward() call. In case it is easier to show, I’ve created a short video that describes and showcases the error, and the colab I use in the video can be found here.

The code below is the simplest reproduction of the error that I could come up with, and it shows how changing the output size can affect whether or not the error occurs. In this example I’ve completely isolated the DecoderRNN and am passing it random/toy values: it generates a sequence, computes the loss for that generated sequence, and then passes the loss backward.

```python
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import torch
import torch.nn as nn

device = torch.device("cuda")

def make_token_tensor(id, vocab_len, should_squeeze=True):
  # print("Making token tensor of id:", id)
  t = torch.zeros(vocab_len).to(device)
  t[id] = 1
  if should_squeeze:
    return t.unsqueeze(0).unsqueeze(0)
  else:
    return t

h_size = 1536 # The Hidden size that goes into the decoder
o_size = 30522 # 30522 = Vocabulary size of default BERT tokenizer

class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size):
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.gru = nn.GRU(output_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        output, hidden = self.gru(input, hidden)
        output = self.softmax(self.out(output[0]))
        return output, hidden


decoder = DecoderRNN(h_size, o_size)
decoder.to(device)

# The initial inputs to the decoder
d_hidden = torch.rand((1,1,h_size)).to(device)
prev_token_pred = make_token_tensor(1, o_size) # Has dimensions 1 x 1 x o_size

ans_tokens = [1, 2, 3, 4] # Imagine that in a real model these would be used for teacher forcing
max_len = len(ans_tokens)
seq_preds = []
for i in range(max_len):
  token_pred, d_hidden = decoder(prev_token_pred, d_hidden)
  prev_token_id = torch.argmax(token_pred)
  prev_token_pred = make_token_tensor(prev_token_id, o_size)
  seq_preds.append(token_pred.squeeze(0))
test_preds = torch.stack(seq_preds)

loss = nn.NLLLoss()
input = test_preds
# each element in target has to have 0 <= value < C
target = torch.tensor(ans_tokens).to(device)
output = loss(input, target)
output.backward()
```

As I show in the video, trying to predict a sequence with length > 1 produces CUDA error: misaligned address, after which the runtime has to be restarted before CUDA can be used again. When I first got this error, I switched the device to the CPU and confirmed that the code works as expected there. Next, I tried predicting a sequence of length 1 and found that the backward() call worked, but as soon as I increased the sequence length it broke again. I then reduced h_size and o_size to 100 each and began increasing them (roughly doubling them and trying out values in between), and the model worked at every step. The (h_size, o_size) pairs I tried include (100, 100), (1000, 1000), (1500, 1000), (1536, 1000), (1536, 1022), (1536, 2000), (1536, 4000), (1536, 8000), (1536, 16000), (1536, 32000), and (1536, 64000).
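
Because a misaligned-address error leaves CUDA unusable until the runtime is restarted, a size sweep like the one above is easier to run if every trial gets its own Python process. The following is a minimal sketch using only the standard library. `run_trial`, `sweep`, and the assumption that the repro script reads `h_size` and `o_size` from `sys.argv` are hypothetical additions, not part of the original repro. Each child also gets `CUDA_LAUNCH_BLOCKING=1`, matching the first two lines of the repro.

```python
import os
import subprocess
import sys


def run_trial(script: str, h_size: int, o_size: int, timeout: float = 600.0) -> bool:
    """Run `script` in a fresh Python process and report whether it succeeded.

    The child sees h_size and o_size as sys.argv[1] and sys.argv[2]. A fresh
    process per trial means a CUDA error in one trial cannot break later ones.
    """
    # Assumption: force synchronous kernel launches in every child process.
    env = {**os.environ, "CUDA_LAUNCH_BLOCKING": "1"}
    try:
        result = subprocess.run(
            [sys.executable, "-c", script, str(h_size), str(o_size)],
            env=env,
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        return False
    if result.returncode != 0:
        # Print the last line of the child's stderr (e.g. the RuntimeError).
        tail = result.stderr.strip().splitlines()[-1:]
        print(f"({h_size}, {o_size}) failed: {tail[0] if tail else 'no stderr'}")
    return result.returncode == 0


def sweep(script: str, sizes):
    """Map each (h_size, o_size) pair to True (ran cleanly) or False."""
    return {(h, o): run_trial(script, h, o) for h, o in sizes}


# Hypothetical usage: REPRO would hold the repro above as a string, edited to
# set h_size = int(sys.argv[1]) and o_size = int(sys.argv[2]).
# results = sweep(REPRO, [(1536, 30520), (1536, 30522)])
```

Running each trial in a subprocess costs a CUDA initialization per size, but it keeps one bad configuration from ending the whole sweep.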

To me, this confirmed that the error was not caused by a lack of available memory. I then tried (1536, 30000) and (1536, 30500), and both worked. Finally, I tried (1536, 30520), and that also successfully generated a prediction for a sequence of length 4.
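
The memory argument can be checked with a little arithmetic. The sketch below counts the decoder's parameters from the shapes PyTorch documents for a single-layer `nn.GRU` (`weight_ih_l0` of shape (3*hidden, input), `weight_hh_l0` of shape (3*hidden, hidden), and two bias vectors of length 3*hidden) and for `nn.Linear`. It counts fp32 weights only, so it is a lower bound: gradients, activations, and the CUDA context need more. The function names are hypothetical.

```python
def decoder_param_count(h_size: int, o_size: int) -> int:
    """Parameters in DecoderRNN(h_size, o_size): one GRU layer plus a Linear."""
    # nn.GRU(o_size, h_size): weight_ih (3h x o), weight_hh (3h x h),
    # bias_ih (3h) and bias_hh (3h).
    gru = 3 * h_size * o_size + 3 * h_size * h_size + 2 * 3 * h_size
    # nn.Linear(h_size, o_size): weight (o x h) and bias (o).
    linear = o_size * h_size + o_size
    return gru + linear


def fp32_weight_mib(h_size: int, o_size: int) -> float:
    """Size of those weights in MiB at 4 bytes per float32 parameter."""
    return decoder_param_count(h_size, o_size) * 4 / 2**20


for sizes in [(1536, 30522), (1536, 64000)]:
    print(sizes, decoder_param_count(*sizes), f"{fp32_weight_mib(*sizes):.1f} MiB")
```

By this count, the (1536, 64000) configuration, which worked, has roughly twice the parameters of the failing (1536, 30522) one, which fits the conclusion that memory is not the trigger. The count can be cross-checked on a real model with `sum(p.numel() for p in decoder.parameters())`.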

Along the way, I also found other values for o_size that don’t work, including 30521, 30523, 12345, and others. I’ll admit that, although I keep looking for a pattern in these results, I haven’t been able to find one.
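
CUDA reports a misaligned address when a kernel accesses memory at an address that is not a multiple of the alignment its access width requires, so one natural thing to tabulate while hunting for a pattern is how many times each o_size can be halved. The helper below is a hypothetical aid for that search, using the o_size values reported in this post (mostly with h_size = 1536). It does not explain the bug: 1022 (worked) and 30522 (failed) share the same power-of-two factor, so this property alone does not separate the two groups.

```python
# o_size values reported in this post (mostly with h_size = 1536).
WORKED = [1000, 1022, 2000, 4000, 8000, 16000, 30000, 30500, 30520, 32000, 64000]
FAILED = [12345, 30521, 30522, 30523]


def pow2_factor(n: int) -> int:
    """Largest power of two that divides n (n > 0), via the lowest set bit."""
    return n & -n


def describe(sizes, label):
    for n in sizes:
        # fp32_row_bytes is the size of one float32 row of length n.
        print(f"{label:6} o_size={n:6d} pow2_factor={pow2_factor(n):4d} fp32_row_bytes={4 * n}")


describe(WORKED, "worked")
describe(FAILED, "failed")
```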

From my understanding, the output size should not have any direct effect on whether or not the model works, and yet it seems that it does. I added the first two lines (setting CUDA_LAUNCH_BLOCKING=1) to try to get more information about the error from CUDA, but it didn’t tell me anything new. Hopefully the video/colab provides an easy environment to experiment with and see the issue I’ve been running into.

Thanks for any help and let me know if you have any clarifying questions!

Full disclosure: I am a student, so some of the things I claimed in the video might be wrong or misstated, and this decoder is one part of a final project for one of my courses.

Thanks for raising this issue. I’m able to reproduce it and will look into it.