I am unsure how to handle the LSTM feedback mechanism when working with batches of data.
I am using LSTMs to classify sequences of frames in order to make predictions about a deformable object that is being moved in a video. I use a CNN to extract visual features from a set of 12 frames (I found 12 frames to work reasonably well), then feed those features to a single-layer LSTM and take the prediction from the last LSTM output. The model works and reaches around 73% accuracy, which is what I would expect for my data. The LSTM + classifier part is here:
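```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class CNNLSTMModel(nn.Module):
    def __init__(self, feature_extractor, n_classes, lstm_hidden_size=128):
        super(CNNLSTMModel, self).__init__()
        self.feature_extractor = feature_extractor
        # input_size must equal the per-frame feature size produced by
        # feature_extractor (here the extractor is assumed to output n_classes values)
        self.lstm = nn.LSTM(input_size=n_classes, hidden_size=lstm_hidden_size,
                            num_layers=1, batch_first=True)
        # Use the configured hidden size rather than a hard-coded 128
        self.classifier = nn.Linear(lstm_hidden_size, n_classes)
        self.optimizer = optim.Adam(self.parameters(), lr=0.0001)
        # forward() returns log-probabilities, so NLLLoss is the matching loss
        # (CrossEntropyLoss would apply log_softmax a second time)
        self.criterion = nn.NLLLoss()

    def forward(self, x):
        samples, timesteps, c, h, w = x.size()
        # Fold time into the batch dimension so the CNN sees individual frames
        c_in = x.view(samples * timesteps, c, h, w)
        c_out = self.feature_extractor(c_in)
        # Unfold back to (batch, time, features) for the LSTM
        r_in = c_out.view(samples, timesteps, -1)
        # What I am doing
        r_out, _ = self.lstm(r_in)
        # Save prediction from last LSTM output
        classes = self.classifier(r_out[:, -1, :])
        softmax = F.log_softmax(classes, dim=1)
        return softmax
```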
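To check what the `forward()` above actually feeds the LSTM and reads back out, here is a minimal sketch with toy sizes. The tiny `nn.Sequential` feature extractor and all dimensions are hypothetical stand-ins for the real CNN and data; only the fold/unfold with `view` and the `r_out[:, -1, :]` readout mirror the code above. It shows that the reshape keeps each clip's frames in order, and that for a single-layer, unidirectional LSTM the last output step is the same tensor as the final hidden state `h_n`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy sizes (assumed for illustration): 4 clips of 12 frames, 3x8x8 pixels each
samples, timesteps, c, h, w = 4, 12, 3, 8, 8
features, hidden = 10, 128

# Hypothetical stand-in for the CNN: any module mapping (N, C, H, W) -> (N, features)
feature_extractor = nn.Sequential(nn.Flatten(), nn.Linear(c * h * w, features))
lstm = nn.LSTM(input_size=features, hidden_size=hidden, num_layers=1, batch_first=True)

x = torch.randn(samples, timesteps, c, h, w)

# Fold time into the batch so the CNN sees 48 independent frames, then unfold
c_out = feature_extractor(x.view(samples * timesteps, c, h, w))
r_in = c_out.view(samples, timesteps, -1)

# view keeps frame order: row s * timesteps + t of c_out is frame t of clip s
print(torch.equal(r_in[1, 5], c_out[1 * timesteps + 5]))  # True

r_out, (h_n, c_n) = lstm(r_in)
print(r_out.shape)  # torch.Size([4, 12, 128])  one output per time step
print(h_n.shape)    # torch.Size([1, 4, 128])   final step only, per layer

# For a single-layer, unidirectional LSTM the last output step equals h_n
print(torch.allclose(r_out[:, -1, :], h_n[-1]))  # True
```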
My main concern is that I am not explicitly feeding the LSTM's output back into itself. What I feel I should be doing (after reading many examples) is something like this:
```python
r_out, HIDDEN_OUTPUT = self.lstm(r_in, HIDDEN_OUTPUT)
```
Although my model performs reasonably, I feel that I should be handling the hidden state explicitly.
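One way to see what the call without a state argument does: the `nn.LSTM` documentation states that when no `(h_0, c_0)` is passed, both default to zeros. The sketch below (toy sizes assumed) compares `lstm(r_in)` with an explicit zero state. Note that `h_0`/`c_0` have shape `(num_layers, batch, hidden)` even when `batch_first=True`, because `batch_first` only applies to the input and output tensors.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy sizes (assumed for illustration)
batch, timesteps, features, hidden = 4, 12, 10, 128
lstm = nn.LSTM(input_size=features, hidden_size=hidden, num_layers=1, batch_first=True)
r_in = torch.randn(batch, timesteps, features)

# Implicit: no initial state passed, so PyTorch starts from zeros
out_implicit, (h_implicit, c_implicit) = lstm(r_in)

# Explicit: zero states of shape (num_layers, batch, hidden);
# batch_first does not apply to the hidden/cell states
h0 = torch.zeros(1, batch, hidden)
c0 = torch.zeros(1, batch, hidden)
out_explicit, (h_explicit, c_explicit) = lstm(r_in, (h0, c0))

print(torch.allclose(out_implicit, out_explicit))  # True
print(torch.allclose(h_implicit, h_explicit))      # True
```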
I am not aware of any PyTorch magic that ensures proper handling of the hidden state when calling the LSTM the way I do (`r_out, _ = self.lstm(r_in)`). What I suspect right now is that I am doing a one-to-one classification, whereas I want a many-to-one classification (classify a set of frames into a single object instance).
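A quick experiment can test the one-to-one vs. many-to-one suspicion: feed the clip one frame at a time, passing the returned `(h, c)` back in on every step, and compare with a single call over all 12 frames. This sketch assumes toy sizes and uses only `nn.LSTM`'s documented `forward(input, hx=None)` signature. If both paths agree, the full-sequence call already carries the hidden state across frames internally, and `r_out[:, -1, :]` summarises the whole clip.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy sizes (assumed for illustration)
batch, timesteps, features, hidden = 4, 12, 10, 128
lstm = nn.LSTM(input_size=features, hidden_size=hidden, num_layers=1, batch_first=True)
r_in = torch.randn(batch, timesteps, features)

# One call over the whole clip: the recurrence over 12 frames runs inside nn.LSTM
full_out, (h_full, c_full) = lstm(r_in)

# The same recurrence by hand: one frame per call, feeding (h, c) back in
state = None  # None means "start from zeros"
step_outputs = []
for t in range(timesteps):
    frame = r_in[:, t:t + 1, :]        # (batch, 1, features)
    out_t, state = lstm(frame, state)  # state = (h, c) after frame t
    step_outputs.append(out_t)
manual_out = torch.cat(step_outputs, dim=1)  # (batch, timesteps, hidden)

print(torch.allclose(full_out, manual_out, atol=1e-5))  # True
print(torch.allclose(h_full, state[0], atol=1e-5))      # True
```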
If I were to change my code to explicitly feed the hidden state back, then I run into a dilemma regarding batches.
`forward(self, x)` receives a batch of samples `x`. How do I carry the hidden state between subsequent frames of each sequence within that batch? Writing a for-loop inside `forward()` seems like a very bad idea.
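For the batch question, it may help to write the recurrence out by hand. The following is a from-scratch NumPy sketch of a single-layer LSTM, not PyTorch's actual implementation. `lstm_forward` and the random weights are hypothetical, the gate order (i, f, g, o) follows the `nn.LSTM` documentation, and a single combined bias replaces PyTorch's separate `b_ih` and `b_hh`. The hidden state is a `(batch, hidden)` matrix with one row per sample, so the only loop is over the 12 time steps, and each step updates every sample at once. The check at the end runs one sample alone and gets the same result, which shows that no state leaks between samples in a batch.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_forward(x, W_ih, W_hh, b):
    """Hypothetical single-layer LSTM over a batch.

    x: (batch, timesteps, features); W_ih: (4*hidden, features);
    W_hh: (4*hidden, hidden); b: (4*hidden,). Gate order i, f, g, o.
    Returns outputs (batch, timesteps, hidden) and the final (h, c).
    """
    batch, timesteps, _ = x.shape
    hidden = W_hh.shape[1]
    h = np.zeros((batch, hidden))  # one hidden-state row per sample
    c = np.zeros((batch, hidden))  # one cell-state row per sample
    outputs = []
    for t in range(timesteps):  # loop over time only; all samples step together
        gates = x[:, t, :] @ W_ih.T + h @ W_hh.T + b  # (batch, 4*hidden)
        i, f, g, o = np.split(gates, 4, axis=1)
        i, f, g, o = sigmoid(i), sigmoid(f), np.tanh(g), sigmoid(o)
        c = f * c + i * g      # update the cell state of every sample
        h = o * np.tanh(c)     # new hidden state, fed back on the next step
        outputs.append(h)
    return np.stack(outputs, axis=1), (h, c)


rng = np.random.default_rng(0)
batch, timesteps, features, hidden = 4, 12, 10, 8
W_ih = rng.normal(scale=0.1, size=(4 * hidden, features))
W_hh = rng.normal(scale=0.1, size=(4 * hidden, hidden))
b = np.zeros(4 * hidden)
x = rng.normal(size=(batch, timesteps, features))

out_batch, (h_batch, _) = lstm_forward(x, W_ih, W_hh, b)

# Running sample 2 on its own gives the same result as inside the batch
out_single, (h_single, _) = lstm_forward(x[2:3], W_ih, W_hh, b)
print(out_batch.shape)                           # (4, 12, 8)
print(np.allclose(out_batch[2:3], out_single))   # True
```

`nn.LSTM` runs this same time loop internally, in compiled code, when it is given the full `(batch, timesteps, features)` tensor.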
I hope someone can shed some light on this.