Use LSTMs for 4D inputs?

Hi, I have a project where my inputs have shape (batch_size, max_sentences, max_words, embedding_dim). I want to take the last hidden state as the encoded representation of each sentence, so that I end up with a matrix of shape (batch_size, max_sentences, hidden_size). How can I encode this using an LSTM?

Can it be something like this:

import torch
import torch.nn as nn


class LSTMEncoder(nn.Module):

    def __init__(self, vocab_size, emb_dim, hidden_size):
        super(LSTMEncoder, self).__init__()

        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.hidden_size = hidden_size
        self.embed = nn.Embedding(vocab_size, emb_dim)
        self.lstm = nn.LSTMCell(emb_dim, hidden_size)

    def forward(self, seq):
        """
        seq: (batch_size, max_sentences, max_words)
        """
        batch_size, max_sentences = seq.size(0), seq.size(1)
        final_hidden_states = torch.zeros(batch_size, max_sentences, self.hidden_size, device=seq.device)

        for step in range(max_sentences):
            # reset the state so each sentence is encoded independently
            h = torch.zeros(batch_size, self.hidden_size, device=seq.device)
            c = torch.zeros(batch_size, self.hidden_size, device=seq.device)
            embedding = self.embed(seq[:, step])      # (batch_size, max_words, emb_dim)
            for i in range(embedding.size(1)):
                h, c = self.lstm(embedding[:, i], (h, c))    # input: (batch_size, emb_dim), h: (batch_size, hidden_size)
            final_hidden_states[:, step] = h          # last hidden state of this sentence

        return final_hidden_states

It would help to know how you are encoding your input. From the RNNs I've built in the past, an input tensor of shape (batch_size, max_sentences, max_words, embedding_dim) seems redundant. One-hot encoded text would have a shape of (batch_size, seq_len, vocab_size), which can be passed directly to an nn.LSTM layer instantiated with the batch_first parameter set to True. Alternatively, you could have a label-encoded tensor of shape (batch_size, seq_len), where each value is an index into your vocabulary; this can be used as input to an nn.Embedding layer, which outputs a tensor of shape (batch_size, seq_len, embedding_dim).
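To make those shapes concrete, here is a minimal sketch of the label-encoding pipeline just described; all the sizes (vocab_size, emb_dim, and so on) are made-up placeholders:

import torch
import torch.nn as nn

vocab_size, emb_dim, hidden_size = 1000, 128, 256
batch_size, seq_len = 4, 20

# label-encoded input: each value is an index into the vocabulary
tokens = torch.randint(0, vocab_size, (batch_size, seq_len))    # (batch_size, seq_len)

embed = nn.Embedding(vocab_size, emb_dim)
lstm = nn.LSTM(emb_dim, hidden_size, batch_first=True)

x = embed(tokens)                # (batch_size, seq_len, emb_dim)
output, (h_n, c_n) = lstm(x)     # output: (batch_size, seq_len, hidden_size)
                                 # h_n:    (1, batch_size, hidden_size), last hidden state per sequence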

A simple trick I might suggest is to reshape your inputs to (batch_size * num_sentences, max_words, embed_dim) and run them through your LSTM; by taking the last hidden state of the PyTorch nn.LSTM, you get an output of shape (batch_size * num_sentences, hidden_size). Then you can easily reshape that matrix to (batch_size, num_sentences, hidden_size), giving you the last hidden state for each sentence, which is what you want. However, this only works if your sentences are independent of each other. If they are connected (such as paragraphs where each sentence depends on the previous one), I would highly recommend using a hierarchical LSTM instead. Have a look at the PyTorch HAN (Hierarchical Attention Network) implementation here.
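Here is a minimal sketch of that reshape trick, assuming your sentences are already embedded and padded to a common length (the dimension names follow the question, the sizes are placeholders):

import torch
import torch.nn as nn

batch_size, max_sentences, max_words = 4, 6, 20
emb_dim, hidden_size = 128, 256

# embedded input: (batch_size, max_sentences, max_words, emb_dim)
x = torch.randn(batch_size, max_sentences, max_words, emb_dim)

lstm = nn.LSTM(emb_dim, hidden_size, batch_first=True)

# fold the sentence dimension into the batch dimension
x = x.view(batch_size * max_sentences, max_words, emb_dim)
_, (h_n, _) = lstm(x)            # h_n: (1, batch_size * max_sentences, hidden_size)

# unfold back to one hidden state per sentence
sentence_vectors = h_n.squeeze(0).view(batch_size, max_sentences, hidden_size)
# sentence_vectors: (batch_size, max_sentences, hidden_size)

One caveat: with padded sentences, h_n is taken after the padding time steps as well. In practice you may want to wrap the input with nn.utils.rnn.pack_padded_sequence so the last hidden state corresponds to the real last word of each sentence.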