Shrinking one dim to 1 and keeping the last non-zero tensor

Can’t tell if this is the right place to post this. Let me know if I’ve made a misstep in any part of this process:

I’m experimenting with training an LSTM model where the input is batches of sequences of embedded tokens. I wrote a collate_fn for the DataLoader that pads every sequence in the batch with zero vectors up to the length of the longest sequence in the batch.
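For reference, a collate_fn of this kind might look like the minimal sketch below. It assumes each dataset item is a `(sequence, label)` pair where `sequence` is a float tensor of shape `(seq_len, embedding_dim)`; `pad_collate` is a hypothetical name. Besides padding with `torch.nn.utils.rnn.pad_sequence`, it also returns each sequence's true length, so the lengths never have to be recovered from the zero vectors later.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_collate(batch):
    """Hypothetical collate_fn: batch is a list of (sequence, label) pairs,
    each sequence a float tensor of shape (seq_len, embedding_dim)."""
    sequences, labels = zip(*batch)
    # Record the true lengths before padding, so they never need to be inferred from zeros
    lengths = torch.tensor([seq.size(0) for seq in sequences])
    # Pad at the end with zero vectors up to the longest sequence in the batch
    padded = pad_sequence(list(sequences), batch_first=True, padding_value=0.0)
    labels = torch.stack([torch.as_tensor(label) for label in labels])
    return padded, labels, lengths
```

It would be passed as `DataLoader(dataset, batch_size=..., collate_fn=pad_collate)`; note that the loader then yields three items per batch instead of two.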

The last LSTM output is passed to a fully connected (FC) layer followed by a sigmoid to produce the output.

The model was broken at this point because I was just taking the LSTM output at the final time step, which, for any sequence shorter than the longest one in the batch, is the output after the model has consumed a run of zero padding vectors. So now I’m trying to pick out the output corresponding to the last non-zero input, using the code below. Is there a better way?

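```python
import torch
import torch.nn as nn


class LSTMModel(nn.Module):
    def __init__(self, embedding_dim, output_dim, hidden_dim, num_layers, drop_prob=0.5):
        super().__init__()
        # nn.LSTM applies dropout only between stacked layers, so it warns when num_layers == 1
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=num_layers, dropout=drop_prob, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, X):
        # X shape: (batch_size, max_seq_len, embedding_dim), zero-padded at the end
        lstm_out, _ = self.lstm(X)  # lstm_out shape: (batch_size, max_seq_len, hidden_dim)

        # A time step is real (not padding) if any of its embedding values is non-zero.
        # torch.sum(X, dim=2) could wrongly flag a real token whose values cancel out to 0.
        is_real = (X != 0).any(dim=2)  # shape: (batch_size, max_seq_len)
        non_zero_count = torch.count_nonzero(is_real, dim=1)  # shape: (batch_size,)
        last_non_zero_index = non_zero_count.long() - 1  # convert counts to zero-based indices

        # Pair row i with time step last_non_zero_index[i].
        # lstm_out[:, last_non_zero_index] instead applies every index to every row,
        # giving shape (batch_size, batch_size, hidden_dim).
        full_first_idx = torch.arange(lstm_out.size(0), device=lstm_out.device)
        last_non_zero_vectors = lstm_out[full_first_idx, last_non_zero_index]  # (batch_size, hidden_dim)

        out = self.fc(last_non_zero_vectors)
        out = self.sigmoid(out)
        return out
```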
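The shape surprise mentioned in the indexing comment comes from how PyTorch advanced indexing works. The toy example below uses made-up sizes and values (not real LSTM output) to show the difference between slicing with one index tensor and pairing two index tensors, plus an equivalent `torch.gather` call.

```python
import torch

batch_size, max_seq_len, hidden_dim = 3, 4, 2
# Toy stand-in for lstm_out: every value at [b, t, :] is 10*b + t, so the picks are easy to read
base = (10 * torch.arange(batch_size).view(-1, 1) + torch.arange(max_seq_len)).float()
lstm_out = base.unsqueeze(-1).expand(batch_size, max_seq_len, hidden_dim)
last_idx = torch.tensor([1, 3, 0])  # hypothetical last real time step of each sequence

# Slice + one index tensor: every row is indexed with all of last_idx
wrong = lstm_out[:, last_idx]
# Two index tensors are paired element-wise: row b with time step last_idx[b]
rows = torch.arange(batch_size)
picked = lstm_out[rows, last_idx]
# Same selection via gather: indices expanded to (batch_size, 1, hidden_dim)
gather_idx = last_idx.view(-1, 1, 1).expand(-1, 1, hidden_dim)
gathered = lstm_out.gather(1, gather_idx).squeeze(1)

print(wrong.shape)                    # torch.Size([3, 3, 2])
print(picked[:, 0].tolist())          # [1.0, 13.0, 20.0]
print(torch.equal(picked, gathered))  # True
```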

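```python
# Build the model once, outside the loop; re-creating it every batch would reset its weights
test_model = LSTMModel(EMBEDDING_DIM, OUTPUT_DIM, hidden_dim, num_layers, drop_prob)

for input_batch, labels in train_loader:
    predictions = test_model(input_batch)  # shape: (batch_size, OUTPUT_DIM)
```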
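One commonly used alternative is to keep the true sequence lengths and pack the batch with `torch.nn.utils.rnn.pack_padded_sequence`, so the LSTM never steps through the padding at all. The sketch below assumes the collate_fn also returns a 1-D tensor of true lengths (the current collate_fn does not, so that part is an assumption). `PackedLSTMModel` is a hypothetical name, and it uses `torch.sigmoid` in place of the `nn.Sigmoid` module, which computes the same thing.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence


class PackedLSTMModel(nn.Module):
    """Hypothetical variant of LSTMModel that takes the true sequence lengths
    instead of inferring them from zero vectors."""

    def __init__(self, embedding_dim, output_dim, hidden_dim, num_layers, drop_prob=0.5):
        super().__init__()
        # Only pass dropout when there are stacked layers, avoiding nn.LSTM's warning
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=num_layers,
                            dropout=drop_prob if num_layers > 1 else 0.0, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, X, lengths):
        # Sequence i is processed for its first lengths[i] steps only; padding is skipped.
        # lengths must live on the CPU; enforce_sorted=False accepts any batch order.
        packed = pack_padded_sequence(X, lengths.cpu(), batch_first=True, enforce_sorted=False)
        _, (h_n, _) = self.lstm(packed)
        # h_n: (num_layers, batch_size, hidden_dim), restored to the original batch order.
        # h_n[-1] is the top layer's hidden state at each sequence's own last real step.
        return torch.sigmoid(self.fc(h_n[-1]))
```

With this variant the loop would unpack three items per batch, e.g. `for input_batch, labels, lengths in train_loader: test_model(input_batch, lengths)`. Because the lengths are known up front, a real token whose embedding happens to be all zeros can no longer be mistaken for padding.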