RuntimeError: Expected object of backend CPU but got backend CUDA for argument #3 'index'

When training a Siamese LSTM model on GPU (it works fine on CPU), the following error occurs:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-19-98abbff14f58> in <module>()
     12     len_strs2 = len_strs2.to(device)
     13 
---> 14     train_one_step(model, optimizer, criterion, train_batch_a, train_batch_b, train_labels, len_strs1, len_strs2)
     15     # test_one_step(model, criterion, train_batch_a, train_batch_b, train_labels, len_strs1, len_strs2)

9 frames
<ipython-input-14-d8c1bfc4d136> in train_one_step(model, optimizer, criterion, train_batch_a, train_batch_b, train_labels, len_strs1, len_strs2)
      2     model.train()
      3 
----> 4     output = model.forward(Variable(train_batch_a), Variable(train_batch_b), len_strs1, len_strs2)
      5     output = output.squeeze()
      6     train_labels = Variable(train_labels).squeeze()

<ipython-input-8-b55412d97ce1> in forward(self, s1, s2, s1_lengths, s2_lengths)
     27         h2, c2 = self.encoder.initHiddenCell(batch_size)
     28 
---> 29         v1, h1, c1 = self.encoder(s1, h1, c1, s1_lengths)
     30         v2, h2, c2 = self.encoder(s2, h2, c2, s2_lengths)
     31 

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    545             result = self._slow_forward(*input, **kwargs)
    546         else:
--> 547             result = self.forward(*input, **kwargs)
    548         for hook in self._forward_hooks.values():
    549             hook_result = hook(self, input, result)

<ipython-input-7-1955254cbea4> in forward(self, input1, hidden, cell, input_lengths)
     26         # input_lengths = torch.as_tensor(input_lengths, dtype=torch.int64, device='cpu')
     27         input1 = torch.nn.utils.rnn.pack_padded_sequence(input1, input_lengths, batch_first=False, enforce_sorted=False)
---> 28         output, (hidden, cell) = self.lstm(input1, (hidden, cell))
     29         return output, hidden, cell

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    545             result = self._slow_forward(*input, **kwargs)
    546         else:
--> 547             result = self.forward(*input, **kwargs)
    548         for hook in self._forward_hooks.values():
    549             hook_result = hook(self, input, result)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py in forward(self, input, hx)
    560     def forward(self, input, hx=None):
    561         if isinstance(input, PackedSequence):
--> 562             return self.forward_packed(input, hx)
    563         else:
    564             return self.forward_tensor(input, hx)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py in forward_packed(self, input, hx)
    552         max_batch_size = int(max_batch_size)
    553 
--> 554         output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
    555 
    556         output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py in forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices)
    519             # Each batch of the hidden state should match the input sequence that
    520             # the user believes he/she is passing in.
--> 521             hx = self.permute_hidden(hx, sorted_indices)
    522 
    523         self.check_forward_args(input, hx, batch_sizes)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py in permute_hidden(self, hx, permutation)
    506         if permutation is None:
    507             return hx
--> 508         return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation)
    509 
    510     def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py in apply_permutation(tensor, permutation, dim)
     19 def apply_permutation(tensor, permutation, dim=1):
     20     # type: (Tensor, Tensor, int) -> Tensor
---> 21     return tensor.index_select(dim, permutation)
     22 
     23 

RuntimeError: Expected object of backend CPU but got backend CUDA for argument #3 'index'

The complete code is on Google Colab:
https://colab.research.google.com/drive/1ACehPtimpYJoeo0ZAAMhTtt8Bf5uCQxL

The hidden and cell tensors initialized for the LSTM were being created on the CPU, while the packed input (and therefore the sorted_indices permutation produced by pack_padded_sequence with enforce_sorted=False) was on the GPU. Inside permute_hidden, index_select then received a CPU hidden state with a CUDA index tensor, which is exactly the mismatch the error complains about.
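A minimal sketch of the same mismatch (the sizes and variable names here are made up for illustration, a CUDA device is assumed to be available, and the exact wording of the error depends on the PyTorch version):

import torch
import torch.nn as nn

device = torch.device('cuda')
lstm = nn.LSTM(input_size=8, hidden_size=16).to(device)

x = torch.randn(5, 3, 8, device=device)   # (seq_len, batch, embed) on the GPU
lengths = torch.tensor([5, 3, 2])          # lengths stay on the CPU
packed = nn.utils.rnn.pack_padded_sequence(x, lengths, enforce_sorted=False)

h0 = torch.zeros(1, 3, 16)                 # hidden state created on the CPU
c0 = torch.zeros(1, 3, 16)                 # cell state created on the CPU

# out, (hn, cn) = lstm(packed, (h0, c0))                       # raises the device-mismatch error
out, (hn, cn) = lstm(packed, (h0.to(device), c0.to(device)))   # works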

The error message is very unhelpful. Moving the hidden and cell states onto the same device as the input before calling the LSTM, as in the following code, removed the error.

import torch
import torch.nn as nn
from torch.autograd import Variable

# `device` is assumed to be defined elsewhere in the notebook, e.g.
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class LSTMEncoder(nn.Module):
    def __init__(self, opt):
        super(LSTMEncoder, self).__init__()
        self.embed_size = opt.embedding_dims
        self.hidden_size = opt.hidden_dims
        self.num_layers = opt.num_layers
        self.bidir = opt.lstm_bidir
        self.padding_idx = opt.padding_idx
        if self.bidir:
            self.direction = 2
        else:
            self.direction = 1
        self.dropout = opt.lstm_dropout
        
        self.lstm = nn.LSTM(input_size=opt.embedding_dims, hidden_size=self.hidden_size, dropout=self.dropout,
                            num_layers=self.num_layers, bidirectional=self.bidir)

    def initHiddenCell(self, batch_size):
        rand_hidden = Variable(torch.zeros(self.direction * self.num_layers, batch_size, self.hidden_size))
        rand_cell = Variable(torch.zeros(self.direction * self.num_layers, batch_size, self.hidden_size))
        return rand_hidden, rand_cell

    def forward(self, input1, hidden, cell):
        # input1 = self.embedding(input1)
        # input_lengths = torch.as_tensor(input_lengths, dtype=torch.int64, device='cpu')
        # input1 = torch.nn.utils.rnn.pack_padded_sequence(input1, input_lengths, batch_first=False, enforce_sorted=False)
        output, (hidden, cell) = self.lstm(input1, (hidden, cell))
        return output, hidden, cell

class Siamese_lstm(nn.Module):
    def __init__(self, opt):
        super(Siamese_lstm, self).__init__()

        self.encoder = LSTMEncoder(opt)

        self.input_dim = int(1 * self.encoder.direction * self.encoder.hidden_size)
        
        self.classifier = nn.Sequential(
            nn.Linear(self.input_dim, int(self.input_dim/2)),
            nn.ReLU(),
            nn.Linear(int(self.input_dim/2), int(self.input_dim/4)),
            nn.ReLU(),
            nn.Linear(int(self.input_dim/4), 1),
            nn.Sigmoid()
        )
        
        self.embedding = nn.Embedding(num_embeddings=opt.vocab_size, embedding_dim=opt.embedding_dims,
                                      padding_idx=opt.padding_idx, max_norm=None, scale_grad_by_freq=False, sparse=False)

    def forward(self, s1, s2, s1_lengths, s2_lengths):
        batch_size = s1.size()[1]
        if device.type == 'cuda':
            max_length = torch.cuda.LongTensor(torch.cat((s1_lengths, s2_lengths))).max().item()
        else:
            max_length = torch.LongTensor(torch.cat((s1_lengths, s2_lengths))).max().item()

        # init hidden, cell
        h1, c1 = self.encoder.initHiddenCell(batch_size)
        h2, c2 = self.encoder.initHiddenCell(batch_size)
        
        s1 = self.embedding(s1)
        s1 = torch.nn.utils.rnn.pack_padded_sequence(s1, s1_lengths, batch_first=False, enforce_sorted=False)
        h1 = h1.to(device)
        c1 = c1.to(device)
        v1, h1, c1 = self.encoder(s1, h1, c1)
        
        s2 = self.embedding(s2)
        s2 = torch.nn.utils.rnn.pack_padded_sequence(s2, s2_lengths, batch_first=False, enforce_sorted=False)
        h2 = h2.to(device)
        c2 = c2.to(device)
        v2, h2, c2 = self.encoder(s2, h2, c2)
        
        v1, l1 = torch.nn.utils.rnn.pad_packed_sequence(v1, batch_first=False, total_length=max_length)
        v2, l2 = torch.nn.utils.rnn.pad_packed_sequence(v2, batch_first=False, total_length=max_length)
        # print(v1)
        if device.type == 'cuda':
            batch_indices = torch.cuda.LongTensor(range(batch_size))
        else:
            batch_indices = torch.LongTensor(range(batch_size))
        v1 = v1[l1-1,batch_indices,:]
        v2 = v2[l2-1,batch_indices,:]
        # features = torch.cat((v1,v2), 1)
        features = abs(v1-v2)
        output = self.classifier(features)

        return output


I'm re-implementing the code of "Attention-based Context Aware Reasoning for Situation Recognition". When I execute main_ggnn_baseline.py, I get the error given above.
Here is the link to the GitHub repository:
https://github.com/thilinicooray/context-aware-reasoning-for-sr