[solved] Does this LSTM training strategy make sense? (Updating parameters after multiple forward passes)

Hi,

I’m training an LSTM model and wondering whether my training code makes sense.

My input sequences vary a lot in length and I don’t want to pad them, so I forward one sample at a time. However, I’d like to update the parameters only after every 32 samples or so.

So I forward 32 times and store each forward pass’s loss tensor, then call backward() on every stored loss, and finally do an optimizer step. Does this make sense?
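My understanding is that gradients accumulate in .grad across backward() calls until zero_grad(), so looping backward() over the stored losses should give the same gradients as a single backward() on their sum. A quick toy check of what I mean (nothing to do with my real model):

import torch

w = torch.ones(3, requires_grad=True)

def make_losses():
    # three toy losses that all depend on w
    return [(w * i).sum() for i in range(1, 4)]

# variant 1: backward() on each stored loss; gradients accumulate in w.grad
for l in make_losses():
    l.backward()
grad_per_loss = w.grad.clone()

# variant 2: a single backward() on the summed loss
w.grad = None
torch.stack(make_losses()).sum().backward()
grad_summed = w.grad.clone()

print(torch.allclose(grad_per_loss, grad_summed))  # prints True

Here’s my actual model and training loop: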

import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 1  # each forward pass sees a single sequence

class LSTM(nn.Module):
    def __init__(self, input_channels, lstm_hidden_size=100, lstm_num_layers=2):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(input_channels,
                            lstm_hidden_size,
                            lstm_num_layers,
                            bias=False,
                            bidirectional=True)
        self.lstm_num_layers = lstm_num_layers
        self.lstm_hidden_size = lstm_hidden_size
        self.reset_hidden_states()
    
    def reset_hidden_states(self):
        # fresh (h_0, c_0); first dim is num_layers * 2 because the LSTM is bidirectional
        self.lstm_states = (
            torch.zeros((self.lstm_num_layers*2, BATCH_SIZE, self.lstm_hidden_size)).to(device),
            torch.zeros((self.lstm_num_layers*2, BATCH_SIZE, self.lstm_hidden_size)).to(device),
        )

    def forward(self, sequence):
        lstm_out, _ = self.lstm(sequence, self.lstm_states)
        # keep only the output at the last time step
        return lstm_out[-1]

UPDATE_STEP = 32

def train():
    train_count = 10000
    batch_loss = []
    model = LSTM(20).to(device)
    batch_count = 0
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    criterion = nn.MSELoss()  # placeholder; the real loss function isn't shown here

    for epoch in range(500):
        for i, (sequence, target) in enumerate(data):  # data yields (sequence, target) pairs
            out = model(sequence)
            loss = criterion(out, target)
            batch_loss.append(loss)  # keep the loss tensor (and its graph) until the update
            model.reset_hidden_states()
            batch_count += 1

            if batch_count == UPDATE_STEP or i == train_count - 1:
                # backward on each stored loss; gradients accumulate across the calls
                for l in batch_loss:
                    l.backward()
                optimizer.step()

                # reset for the next accumulation window
                optimizer.zero_grad()
                batch_count = 0
                batch_loss = []
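
One variant I also considered is calling backward() right after each sample instead of storing 32 loss tensors, so each sample’s graph can be freed immediately and only the accumulated gradients stay in memory. A rough sketch, reusing the placeholder criterion/data from above and ignoring the last partial window for brevity:

def train_accumulate():
    model = LSTM(20).to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.001)
    criterion = nn.MSELoss()  # placeholder loss, as above

    for epoch in range(500):
        optimizer.zero_grad()
        for i, (sequence, target) in enumerate(data):
            out = model(sequence)
            model.reset_hidden_states()

            # scale so the accumulated gradient is the mean over the window
            loss = criterion(out, target) / UPDATE_STEP
            loss.backward()  # frees this sample's graph right away

            if (i + 1) % UPDATE_STEP == 0:
                optimizer.step()
                optimizer.zero_grad()

As far as I can tell, both versions should produce the same kind of update (up to the loss scaling) and differ mainly in memory usage. Is one of them preferable?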