Seq2seq: index out of range in self

I'm a beginner in PyTorch, and I'm trying to build a seq2seq model where 'en_ids' (word indices) is the input and 'NPY_DATA' (a list of numbers) is the output, as in the picture below.

The seq2seq model looks like this:

class Encoder(nn.Module):
    def __init__(self, input_dim, embedding_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, n_layers, dropout=dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        embedded = self.dropout(self.embedding(src))
        outputs, (hidden, cell) = self.rnn(embedded)
        return hidden, cell

class Decoder(nn.Module):
    def __init__(self, output_dim, embedding_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.embedding = nn.Embedding(output_dim, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, n_layers, dropout=dropout)
        self.fc_out = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell):
        # input = [batch size]
        # hidden = [n layers * n directions, batch size, hidden dim]
        # cell = [n layers * n directions, batch size, hidden dim]
        # n directions in the decoder will both always be 1, therefore:
        # hidden = [n layers, batch size, hidden dim]
        # cell = [n layers, batch size, hidden dim]
        print(f"input shape: {input.shape}")
        print(f"hidden shape: {hidden.shape}")
        print(f"cell shape: {cell.shape}")
        input = input.unsqueeze(0)
        print(f"input shape after unsqueeze: {input.shape}")
        # input = [1, batch size]
        embedded = self.dropout(self.embedding(input))
        print(f"embedded shape: {embedded.shape}")
        #embedded = [1, batch size, embedding dim]
        output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
        print(f"output shape: {output.shape}")
        print(f"hidden shape: {hidden.shape}")
        print(f"cell shape: {cell.shape}")
        # output = [seq length, batch size, hidden dim * n directions]
        # hidden = [n layers * n directions, batch size, hidden dim]
        # cell = [n layers * n directions, batch size, hidden dim]
        # seq length and n directions will always be 1 in this decoder, therefore:
        # output = [1, batch size, hidden dim]
        # hidden = [n layers, batch size, hidden dim]
        # cell = [n layers, batch size, hidden dim]
        prediction = self.fc_out(output.squeeze(0))
        print(f"prediction shape: {prediction.shape}")
        print(prediction)
        print(prediction.shape)
        # prediction = [batch size, output dim]
        return prediction, hidden, cell

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        assert (
            encoder.hidden_dim == decoder.hidden_dim
        ), "Hidden dimensions of encoder and decoder must be equal!"
        assert (
            encoder.n_layers == decoder.n_layers
        ), "Encoder and decoder must have equal number of layers!"

    def forward(self, src, trg, teacher_forcing_ratio):
        # src = [src length, batch size]
        # trg = [trg length, batch size]
        # teacher_forcing_ratio is probability to use teacher forcing
        # e.g. if teacher_forcing_ratio is 0.75 we use ground-truth inputs 75% of the time
        # trg.shape[1] == 80
        # trg.shape[0] == 8312
        # self.decoder.output_dim == 1662
        batch_size = trg.shape[1]
        trg_length = trg.shape[0]
        trg_vocab_size = self.decoder.output_dim
        # tensor to store decoder outputs
        outputs = torch.zeros(trg_length, batch_size, trg_vocab_size).to(self.device)
        # last hidden state of the encoder is used as the initial hidden state of the decoder
        hidden, cell = self.encoder(src)
        # hidden = [n layers * n directions, batch size, hidden dim]
        # cell = [n layers * n directions, batch size, hidden dim]
        # first input to the decoder is the <sos> tokens
        input = trg[0, :]
        # input = [batch size]
        for t in range(1, trg_length):
            # insert input token embedding, previous hidden and previous cell states
            # receive output tensor (predictions) and new hidden and cell states
            output, hidden, cell = self.decoder(input, hidden, cell)
            # output = [batch size, output dim]
            # hidden = [n layers, batch size, hidden dim]
            # cell = [n layers, batch size, hidden dim]
            # place predictions in a tensor holding predictions for each token
            outputs[t] = output
            # decide if we are going to use teacher forcing or not
            teacher_force = random.random() < teacher_forcing_ratio
            # get the highest predicted token from our predictions
            top1 = output.argmax(1)
            # if teacher forcing, use actual next token as next input
            # if not, use predicted token
            input = trg[t] if teacher_force else top1
            # input = [batch size]
        return outputs

input_dim = len(en_vocab)
output_dim = 1662
encoder_embedding_dim = 300
decoder_embedding_dim = 300
hidden_dim = 1024
n_layers = 2
encoder_dropout = 0.5
decoder_dropout = 0.5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder = Encoder(
    input_dim,
    encoder_embedding_dim,
    hidden_dim,
    n_layers,
    encoder_dropout,
)

decoder = Decoder(
    output_dim,
    decoder_embedding_dim,
    hidden_dim,
    n_layers,
    decoder_dropout,
)

model = Seq2Seq(encoder, decoder, device).to(device)

def train_fn(
    model, data_loader, optimizer, criterion, clip, teacher_forcing_ratio, device
):
    model.train()
    epoch_loss = 0
    for i, batch in enumerate(data_loader):
        src = batch["en_ids"].to(device).long()   # torch.Size([61, 80])
        trg = batch["NPY_DATA"].to(device).long() # torch.Size([8312, 80])
        # src = [src length, batch size]
        # trg = [trg length, batch size]
        optimizer.zero_grad()
        output = model(src, trg, teacher_forcing_ratio)
        # output = [trg length, batch size, trg vocab size]
        output_dim = output.shape[-1]
        output = output[1:].view(-1, output_dim)
        # output = [(trg length - 1) * batch size, trg vocab size]
        trg = trg[1:].view(-1)
        # trg = [(trg length - 1) * batch size]
        loss = criterion(output, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(data_loader)

I train the model with this code:

n_epochs = 10
clip = 1.0
teacher_forcing_ratio = 0.5

best_valid_loss = float("inf")

for epoch in tqdm.tqdm(range(n_epochs)):
    train_loss = train_fn(
        model,
        train_data_loader,
        optimizer,
        criterion,
        clip,
        teacher_forcing_ratio,
        device,
    )
    valid_loss = evaluate_fn(
        model,
        valid_data_loader,
        criterion,
        device,
    )
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), "tut1-model.pt")
    print(f"\tTrain Loss: {train_loss:7.3f} | Train PPL: {np.exp(train_loss):7.3f}")
    print(f"\tValid Loss: {valid_loss:7.3f} | Valid PPL: {np.exp(valid_loss):7.3f}")

However, I get the error IndexError: index out of range in self, as shown below:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[2122], line 8
      5 best_valid_loss = float("inf")
      7 for epoch in tqdm.tqdm(range(n_epochs)):
----> 8     train_loss = train_fn(
      9         model,
     10         train_data_loader,
     11         optimizer,
     12         criterion,
     13         clip,
     14         teacher_forcing_ratio,
     15         device,
     16     )
     17     valid_loss = evaluate_fn(
     18         model,
     19         valid_data_loader,
     20         criterion,
     21         device,
     22     )
     23     if valid_loss < best_valid_loss:

Cell In[2120], line 12
      9 # src = [src length, batch size]
     10 # trg = [trg length, batch size]
...
   2235     # remove once script supports set_grad_enabled
   2236     _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2237 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)

IndexError: index out of range in self

I tried to narrow down the issue by printing the dimensions at each step before the error occurs, but I still cannot solve it. Could you please give me some advice?

  0%|          | 0/10 [00:00<?, ?it/s]
input shape: torch.Size([80])
hidden shape: torch.Size([2, 80, 1024])
cell shape: torch.Size([2, 80, 1024])
input shape after unsqueeze: torch.Size([1, 80])
embedded shape: torch.Size([1, 80, 300])
output shape: torch.Size([1, 80, 1024])
hidden shape: torch.Size([2, 80, 1024])
cell shape: torch.Size([2, 80, 1024])
prediction shape: torch.Size([80, 1662])
tensor([[ 0.1316,  0.0260, -0.0131,  ...,  0.0156,  0.0229,  0.0830],
        [ 0.1949,  0.1038, -0.0335,  ...,  0.0473, -0.0440,  0.0506],
        [ 0.1632,  0.0401, -0.0797,  ...,  0.0158,  0.0502,  0.0924],
        ...,
        [ 0.0946,  0.0476, -0.0432,  ..., -0.0108,  0.0349,  0.0327],
        [ 0.1573,  0.0708, -0.0721,  ...,  0.0490,  0.0419,  0.0162],
        [ 0.1679,  0.0769, -0.0531,  ...,  0.0339, -0.0432,  0.0603]],
       grad_fn=<AddmmBackward0>)
torch.Size([80, 1662])
(the same shape printouts and prediction tensors repeat for each decoder step until the error is raised)

The dimensions are not at fault here, but the min and max indices of the index vector are.
Print these stats for the input to the nn.Embedding layer and make sure all indices are in the range [0, input_dim - 1].
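
For example, a minimal sketch of such a check (check_indices is a hypothetical helper, not part of the original code; src, trg, input_dim, and output_dim are the variables from the question) that could be dropped into train_fn right before the model call:

def check_indices(name, idx, num_embeddings):
    # Every index fed to an nn.Embedding must lie in [0, num_embeddings - 1].
    lo, hi = idx.min().item(), idx.max().item()
    assert 0 <= lo and hi < num_embeddings, (
        f"{name}: index range [{lo}, {hi}] is outside [0, {num_embeddings - 1}]"
    )

# inside train_fn, before output = model(src, trg, teacher_forcing_ratio):
check_indices("src", src, input_dim)    # indices fed to the encoder embedding
check_indices("trg", trg, output_dim)   # indices fed to the decoder embedding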

@ptrblck I printed the max and min values, as below:

  • train_data['NPY_DATA'].min() = -1.3136
  • train_data['NPY_DATA'].max() = 1.7036
  • train_data['en_ids'].max() = 171
  • train_data['en_ids'].min() = 0

If I understand your explanation correctly, train_data['NPY_DATA'].min() = -1.3136 is the root cause of the error, because casting it to long truncates it to -1, a negative index that is outside the valid range [0, output_dim - 1] = [0, 1661] of the decoder's nn.Embedding? If I misunderstand, could you please tell me.
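
For reference, a minimal sketch that reproduces this failure mode in isolation (the 1662 x 300 size is the decoder embedding from the parameters below; the value -1.3136 is the observed NPY_DATA minimum):

import torch
import torch.nn as nn

emb = nn.Embedding(1662, 300)          # decoder embedding: valid indices are 0..1661
idx = torch.tensor([-1.3136]).long()   # .long() truncates toward zero -> tensor([-1])
print(idx)                             # tensor([-1])
emb(idx)                               # raises IndexError: index out of range in self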

Parameters:
input_dim = len(en_vocab) = 172
output_dim = 1662
encoder_embedding_dim = 300
decoder_embedding_dim = 300
hidden_dim = 1024
n_layers = 2
encoder_dropout = 0.5
decoder_dropout = 0.5