RuntimeError: "addmm_cuda" not implemented for 'Long'

Hi, I am trying to implement a BiLSTM language model and am getting a strange error.

import math
import torch
import torch.nn as nn
from torch.autograd import Variable

class BiLSTM(nn.Module):
  def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, dropout_rate, tie_weights):
    super().__init__()

    self.num_layers = num_layers
    self.hidden_dim = hidden_dim
    self.embedding_dim = embedding_dim

    self.embedding = nn.Embedding(vocab_size, embedding_dim)
    self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=num_layers, 
                        dropout=dropout_rate, batch_first=True, bidirectional=True)
    self.dropout = nn.Dropout(dropout_rate)
    self.linear = nn.Linear(hidden_dim*2, vocab_size) # hidden_dim*2 because the LSTM is bidirectional

    if tie_weights:
      # Embedding and hidden layer need to be same size for weight tying
      assert embedding_dim == hidden_dim, 'Cannot tie weights, check dimensions'
      self.linear.weight = self.embedding.weight

    self.init_weights()

    self.hidden = None
    self.cell = None

  def forward(self, x, hidden):
    hidden_0 = Variable(torch.zeros(x.size(0), self.hidden_dim)).to(device)
    cell_0 = Variable(torch.zeros(x.size(0), self.hidden_dim)).to(device)
    output  = self.embedding(x)
    output, (h, c) = self.lstm(x, (hidden_0, cell_0))
    final_state = h[0].view(self.num_layers, 2, batch_size, hidden_dim)[-1]
    output = self.linear(output)
    return output, (h, c)

  def init_weights(self):
    init_range_emb = 0.1
    init_range_other = 1/math.sqrt(self.hidden_dim)
    self.embedding.weight.data.uniform_(-init_range_emb, init_range_emb)
    self.linear.weight.data.uniform_(-init_range_other, init_range_other)
    self.linear.bias.data.zero_()


  def init_hidden(self, batch_size):
    hidden = torch.zeros(self.num_layers, batch_size, self.hidden_dim)
    cell = torch.zeros(self.num_layers, batch_size, self.hidden_dim)
    return hidden, cell

  def detach_hidden(self, hidden):
    hidden, cell = hidden
    hidden = hidden.detach()
    cell = cell.detach()
    return hidden, cell

vocab_size = len(vocab)
embedding_dim = 100
hidden_dim = 100
num_layers = 5
dropout_rate = 0.4
tie_weights = True
model = BiLSTM(vocab_size, embedding_dim, hidden_dim, num_layers, dropout_rate, tie_weights)
model.to(device)

import copy
import time

from torch import Tensor

criterion = nn.CrossEntropyLoss()
lr = 20.0  # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

def train(model: nn.Module) -> None:
    model.train()  # turn on train mode
    total_loss = 0.
    log_interval = 200
    start_time = time.time()

    hidden = model.init_hidden(batch_size)
    
    num_batches = len(train_data) // bptt
    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        hidden = model.detach_hidden(hidden)
        data, targets = get_batch(train_data, i)
        seq_len = data.size(0)
        output, hidden = model(data, hidden)
        loss = criterion(output.view(-1, vocab_size), targets)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        
        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            ppl = math.exp(cur_loss)
            print(f'| epoch {epoch:3d} | {batch:5d}/{num_batches:5d} batches | '
                  f'lr {lr:02.2f} | ms/batch {ms_per_batch:5.2f} | '
                  f'loss {cur_loss:5.2f} | ppl {ppl:8.2f}')
            total_loss = 0
            start_time = time.time()

def evaluate(model: nn.Module, eval_data: Tensor) -> float:
    model.eval()  # turn on evaluation mode
    total_loss = 0.
    with torch.no_grad():
        for i in range(0, eval_data.size(0) - 1, bptt):
            data, targets = get_batch(eval_data, i)
            seq_len = data.size(0)
            output = model(data)
            output_flat = output.view(-1, vocab_size)
            total_loss += seq_len * criterion(output_flat, targets).item()
    return total_loss / (len(eval_data) - 1)

best_val_loss = float('inf')
epochs = 50
best_model = None

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(model)
    val_loss = evaluate(model, val_data)
    val_ppl = math.exp(val_loss)
    elapsed = time.time() - epoch_start_time
    print('-' * 89)
    print(f'| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | '
          f'valid loss {val_loss:5.2f} | valid ppl {val_ppl:8.2f}')
    print('-' * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)

    scheduler.step()

Here is the error message

RuntimeError                              Traceback (most recent call last)
<ipython-input-238-453c3f2a9cad> in <cell line: 5>()
      5 for epoch in range(1, epochs + 1):
      6     epoch_start_time = time.time()
----> 7     train(model)
      8     val_loss = evaluate(model, val_data)
      9     val_ppl = math.exp(val_loss)

4 frames
<ipython-input-236-ad74324a3a3a> in train(model)
     20         data, targets = get_batch(train_data, i)
     21         seq_len = data.size(0)
---> 22         output, hidden = model(data, hidden)
     23         loss = criterion(output.view(-1, vocab_size), targets)
     24 

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

<ipython-input-234-1a44a39d1437> in forward(self, x, hidden)
     27     cell_0 = Variable(torch.zeros(x.size(0), self.hidden_dim)).to(device)
     28     output  = self.embedding(x)
---> 29     output, (h, c) = self.lstm(x, (hidden_0, cell_0))
     30     final_state = h[0].view(self.num_layers, 2, batch_size, hidden_dim)[-1]
     31     output = self.linear(output)

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/rnn.py in forward(self, input, hx)
    810         self.check_forward_args(input, hx, batch_sizes)
    811         if batch_sizes is None:
--> 812             result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
    813                               self.dropout, self.training, self.bidirectional, self.batch_first)
    814         else:

RuntimeError: "addmm_cuda" not implemented for 'Long'

I’m not sure what the problem is; any help is appreciated.

Based on this code:

    output  = self.embedding(x)
    output, (h, c) = self.lstm(x, (hidden_0, cell_0))

I assume x is a LongTensor containing vocabulary indices, which is used to index the self.embedding layer. If so, the self.lstm layer will raise this error, since it expects a floating-point input tensor.
I would also guess you want to use the output tensor (the embedding output) as the input to self.lstm instead of the original x input tensor.
Also note that final_state seems to be unused, and remove the Variable usage, as Variables have been deprecated since PyTorch 0.4.
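For illustration, here is a minimal sketch of the intended data flow (hypothetical sizes; on the GPU, passing the Long input directly to the LSTM raises exactly your addmm_cuda error):

import torch
import torch.nn as nn

embedding = nn.Embedding(num_embeddings=100, embedding_dim=16)
lstm = nn.LSTM(input_size=16, hidden_size=16, batch_first=True)

x = torch.randint(0, 100, (4, 7))  # [batch, seq] token indices, dtype torch.long
emb = embedding(x)                 # [batch, seq, 16], dtype torch.float32
out, (h, c) = lstm(emb)            # works: the LSTM receives a float tensor
# lstm(x) would fail, since x is a LongTensor rather than a float tensor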

Thanks for the reply. I adjusted the forward() function

  def forward(self, x, hidden):
    hidden_0 = torch.zeros(x.size(0), self.hidden_dim).to(device)
    cell_0 = torch.zeros(x.size(0), self.hidden_dim).to(device)
    output  = self.embedding(x)
    output, (h, c) = self.lstm(output, (hidden_0, cell_0))
    final_state = h[0].view(self.num_layers, 2, batch_size, hidden_dim)[-1]
    output = self.linear(output)
    return output, (final_state, c)

I’m now getting this error

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-289-453c3f2a9cad> in <cell line: 5>()
      5 for epoch in range(1, epochs + 1):
      6     epoch_start_time = time.time()
----> 7     train(model)
      8     val_loss = evaluate(model, val_data)
      9     val_ppl = math.exp(val_loss)

4 frames
<ipython-input-285-ad74324a3a3a> in train(model)
     20         data, targets = get_batch(train_data, i)
     21         seq_len = data.size(0)
---> 22         output, hidden = model(data, hidden)
     23         loss = criterion(output.view(-1, vocab_size), targets)
     24 

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

<ipython-input-287-0fc0e88817a4> in forward(self, x, hidden)
     27     cell_0 = torch.zeros(x.size(0), self.hidden_dim).to(device)
     28     output  = self.embedding(x)
---> 29     output, (h, c) = self.lstm(output, (hidden_0, cell_0))
     30     final_state = h[0].view(self.num_layers, 2, batch_size, hidden_dim)[-1]
     31     output = self.linear(output)

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/rnn.py in forward(self, input, hx)
    796                         msg = ("For batched 3-D input, hx and cx should "
    797                                f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors")
--> 798                         raise RuntimeError(msg)
    799                 else:
    800                     if hx[0].dim() != 2 or hx[1].dim() != 2:

RuntimeError: For batched 3-D input, hx and cx should also be 3-D but got (2-D, 2-D) tensors

This error points to a wrong shape of the hidden and cell states.
The docs explain the shapes as:

h_0: tensor of shape (D * num_layers, H_out) for unbatched input or (D * num_layers, N, H_out) containing the initial hidden state for each element in the input sequence.
c_0: tensor of shape (D * num_layers, H_cell) for unbatched input or (D * num_layers, N, H_cell) containing the initial cell state for each element in the input sequence.

For a batched input these states should thus be 3D tensors while you are initializing them as 2D tensors via:

hidden_0 = Variable(torch.zeros(x.size(0), self.hidden_dim)).to(device)
cell_0 = Variable(torch.zeros(x.size(0), self.hidden_dim)).to(device)
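As a minimal, self-contained sketch using the dimensions from this thread (and assuming batch_first=True and bidirectional=True as in your posted model), the states would be created like this:

import torch
import torch.nn as nn

num_layers, hidden_dim, emb_dim = 5, 100, 100
batch_size, seq_len = 10, 7
lstm = nn.LSTM(emb_dim, hidden_dim, num_layers=num_layers,
               batch_first=True, bidirectional=True)

# D = 2 for a bidirectional LSTM, so the states are 3-D: [D*num_layers, N, H]
h0 = torch.zeros(2 * num_layers, batch_size, hidden_dim)
c0 = torch.zeros(2 * num_layers, batch_size, hidden_dim)

emb = torch.randn(batch_size, seq_len, emb_dim)  # batch_first input: [N, L, H_in]
out, (h, c) = lstm(emb, (h0, c0))
print(out.shape)  # torch.Size([10, 7, 200]): 2*hidden_dim features per step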

I changed

hidden_0 = Variable(torch.zeros(x.size(0), self.hidden_dim)).to(device)
cell_0 = Variable(torch.zeros(x.size(0), self.hidden_dim)).to(device)

To

hidden_0 = torch.zeros(x.size(0), batch_size, self.hidden_dim).to(device)
cell_0 = torch.zeros(x.size(0), batch_size, self.hidden_dim).to(device)

where batch_size is 10. But I’m still getting errors.

RuntimeError                              Traceback (most recent call last)
<ipython-input-212-453c3f2a9cad> in <cell line: 5>()
      5 for epoch in range(1, epochs + 1):
      6     epoch_start_time = time.time()
----> 7     train(model)
      8     val_loss = evaluate(model, val_data)
      9     val_ppl = math.exp(val_loss)

4 frames
<ipython-input-210-16d7ac1074e2> in train(model)
     20         data, targets = get_batch(train_data, i)
     21         seq_len = data.size(0)
---> 22         output, hidden = model(data, hidden)
     23         loss = criterion(output.view(-1, vocab_size), targets)
     24 

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

<ipython-input-207-0793d8a9b34f> in forward(self, x, hidden)
     29     output, (h, c) = self.lstm(output, (hidden_0, cell_0))
     30     output = self.dropout(output)
---> 31     output = self.linear(output)
     32     return output, (h, c)
     33 

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/linear.py in forward(self, input)
    112 
    113     def forward(self, input: Tensor) -> Tensor:
--> 114         return F.linear(input, self.weight, self.bias)
    115 
    116     def extra_repr(self) -> str:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (100x200 and 100x9922)

I think this is something to do with my model dimensions. This is my model currently

class BiLSTM(nn.Module):
  def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, dropout_rate, tie_weights):
    super().__init__()

    self.num_layers = num_layers
    self.hidden_dim = hidden_dim
    self.embedding_dim = embedding_dim

    self.embedding = nn.Embedding(vocab_size, embedding_dim)
    self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=num_layers, 
                        dropout=dropout_rate, batch_first=True, bidirectional=True)
    self.dropout = nn.Dropout(dropout_rate)
    self.linear = nn.Linear(hidden_dim*2, vocab_size) # hidden_dim*2 because the LSTM is bidirectional

hidden_dim and embedding_dim are 100, num_layers is 5, and vocab_size is 9922.

What should the model dimensions be for a BiLSTM?

Your hidden and cell state shapes are still wrong, since (D * num_layers, N, H_out) and (D * num_layers, N, H_cell) are expected as mentioned above, where N represents the batch size.
In your current example you are using the batch size twice (in dim0 via x.size(0), which equals the batch size for batch_first=True inputs, and in dim1 via batch_size). Initialize dim0 as 2*num_layers.

Sorry, I’ve corrected the state shapes to

    hidden_0 = torch.zeros(2*num_layers, batch_size, self.hidden_dim).to(device)
    cell_0 = torch.zeros(2*num_layers, batch_size, self.hidden_dim).to(device)

but I am still getting the same runtime error as above.

Full error message

RuntimeError                              Traceback (most recent call last)
<ipython-input-175-453c3f2a9cad> in <cell line: 5>()
      5 for epoch in range(1, epochs + 1):
      6     epoch_start_time = time.time()
----> 7     train(model)
      8     val_loss = evaluate(model, val_data)
      9     val_ppl = math.exp(val_loss)

4 frames
<ipython-input-173-16d7ac1074e2> in train(model)
     20         data, targets = get_batch(train_data, i)
     21         seq_len = data.size(0)
---> 22         output, hidden = model(data, hidden)
     23         loss = criterion(output.view(-1, vocab_size), targets)
     24 

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

<ipython-input-171-c83c2df4172c> in forward(self, x, hidden)
     29     output, (h, c) = self.lstm(output, (hidden_0, cell_0))
     30     output = self.dropout(output)
---> 31     output = self.linear(output)
     32     return output, (h, c)
     33 

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/linear.py in forward(self, input)
    112 
    113     def forward(self, input: Tensor) -> Tensor:
--> 114         return F.linear(input, self.weight, self.bias)
    115 
    116     def extra_repr(self) -> str:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (100x200 and 100x9922)

The error points towards a shape mismatch in the feature dimension of self.linear. Check its in_features and make sure the tensor you pass to it has the same number of features in its last dimension.
I don’t know which parts exactly you’ve changed, so I also cannot run the current code.
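One guess based on the model you posted earlier (an assumption, since your current code isn’t shown): with tie_weights=True you assign self.embedding.weight of shape [vocab_size, embedding_dim] = [9922, 100] to self.linear.weight, which overwrites the [9922, 200] weight that nn.Linear(hidden_dim*2, vocab_size) needs for the bidirectional features. A minimal repro of that mismatch:

import torch
import torch.nn as nn

vocab_size, embedding_dim, hidden_dim = 9922, 100, 100
embedding = nn.Embedding(vocab_size, embedding_dim)
linear = nn.Linear(hidden_dim * 2, vocab_size)

print(linear.weight.shape)        # torch.Size([9922, 200])
linear.weight = embedding.weight  # weight tying replaces it with [9922, 100]

features = torch.randn(100, hidden_dim * 2)  # bidirectional LSTM output: 200 features
linear(features)
# RuntimeError: mat1 and mat2 shapes cannot be multiplied (100x200 and 100x9922)

If that is the cause, tying weights with a bidirectional LSTM only works if the tensor fed to self.linear has embedding_dim features, e.g. by first projecting the 2*hidden_dim output down to embedding_dim, or by disabling tie_weights.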