Hello,
I am unsure how to use the LSTM feedback mechanism (the hidden state) when working with batches of data.
I am using an LSTM to classify sequences of frames in order to make predictions about a deformable object that is being moved in a video. I use a CNN to extract visual features from a set of 12 frames (I found 12 frames to work reasonably well), then I feed those extracted features to a single-layer LSTM and take the prediction from the last LSTM output. The model works and I get around 73% accuracy, which is what I would expect for my data. The LSTM+classifier part is here:
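```python
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class CNNLSTMModel(nn.Module):
    def __init__(self, feature_extractor, n_classes, lstm_hidden_size=128):
        super(CNNLSTMModel, self).__init__()
        self.feature_extractor = feature_extractor
        # input_size must match the per-frame feature size produced by feature_extractor
        # (here the extractor is assumed to output n_classes values per frame)
        self.lstm = nn.LSTM(input_size=n_classes, hidden_size=lstm_hidden_size,
                            num_layers=1, batch_first=True)
        # Use lstm_hidden_size instead of a hard-coded 128 so both stay in sync
        self.classifier = nn.Linear(lstm_hidden_size, n_classes)
        self.optimizer = optim.Adam(self.parameters(), lr=0.0001)
        # forward() returns log-probabilities, which NLLLoss expects
        # (CrossEntropyLoss would apply log_softmax a second time)
        self.criterion = nn.NLLLoss()

    def forward(self, x):
        # x: (batch, timesteps, channels, height, width)
        samples, timesteps, c, h, w = x.size()
        # Fold time into the batch dimension so the CNN sees individual frames
        c_in = x.view(samples * timesteps, c, h, w)
        c_out = self.feature_extractor(c_in)
        # Unfold back into sequences: (batch, timesteps, features)
        r_in = c_out.view(samples, timesteps, -1)
        # What I am doing
        r_out, _ = self.lstm(r_in)
        # Save prediction from last LSTM output
        classes = self.classifier(r_out[:, -1, :])
        return F.log_softmax(classes, dim=1)
```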
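A minimal sketch of what the `self.lstm(r_in)` call above returns, using a standalone `nn.LSTM` with an illustrative feature size of 10 standing in for whatever the CNN actually outputs. A single call already runs the recurrence over all 12 timesteps, and for a single-layer, unidirectional LSTM the last step of `r_out` holds the same values as the final hidden state `h_n`:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes: feat=10 is an assumption standing in for the CNN's feature size
batch, timesteps, feat, hidden = 4, 12, 10, 128
lstm = nn.LSTM(input_size=feat, hidden_size=hidden, num_layers=1, batch_first=True)

r_in = torch.randn(batch, timesteps, feat)
r_out, (h_n, c_n) = lstm(r_in)

print(r_out.shape)  # torch.Size([4, 12, 128]): one output per timestep
print(h_n.shape)    # torch.Size([1, 4, 128]): (num_layers, batch, hidden), last step only

# Single-layer, unidirectional: the last output equals the final hidden state
print(torch.allclose(r_out[:, -1, :], h_n[-1]))  # True
```

Note that `h_n` and `c_n` are laid out as `(num_layers, batch, hidden)` even when `batch_first=True`; that flag only affects the input and `r_out`.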
Now my main concern is that I am not explicitly feeding the LSTM's hidden state back into itself. After reading many examples, I feel I should be doing something like this:
```python
r_out, HIDDEN_OUTPUT = self.lstm(r_in, HIDDEN_OUTPUT)
```
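A sketch of what passing the state explicitly changes, again with a standalone `nn.LSTM` and illustrative sizes. The PyTorch documentation states that `h_0` and `c_0` default to zeros when no state is given, so the first two calls below should match; feeding the returned `(h_n, c_n)` back in only makes a difference when one sequence is split across several calls:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=10, hidden_size=16, batch_first=True)
x = torch.randn(2, 12, 10)  # 2 sequences of 12 frames each (illustrative sizes)

# 1) Omitting the state is the same as passing an all-zero state
h0 = torch.zeros(1, 2, 16)  # (num_layers, batch, hidden)
c0 = torch.zeros(1, 2, 16)
out_default, _ = lstm(x)
out_zeros, _ = lstm(x, (h0, c0))
print(torch.allclose(out_default, out_zeros))  # True

# 2) Explicit state matters when a sequence is processed in chunks
out_a, state = lstm(x[:, :6])     # frames 0..5, starting from zeros
out_b, _ = lstm(x[:, 6:], state)  # frames 6..11, continuing from chunk A's state
out_full, _ = lstm(x)             # all 12 frames in one call
print(torch.allclose(torch.cat([out_a, out_b], dim=1), out_full, atol=1e-5))  # True
```

If each 12-frame clip is an independent sample, starting every clip from the zero state (the default) is the usual choice. When state is instead carried across training batches, it is common to `.detach()` `h` and `c` between batches so backpropagation stops at the chunk boundary.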
Although my model works to some degree, I feel that I should handle the hidden state explicitly.
I am not aware of any PyTorch magic that ensures the hidden state is handled properly when the LSTM is called the way I do it (`r_out, _ = self.lstm(r_in)`). What I suspect right now is that I am doing one-to-one classification, whereas what I want is many-to-one (classifying a set of frames as a single object instance).
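To make the feedback visible, the sketch below steps an `nn.LSTM` through the frames one at a time, passing `(h, c)` back in on every step, and compares the result with a single call over the whole sequence. Sizes are illustrative, and the loop exists only for this comparison:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=10, hidden_size=16, batch_first=True)
x = torch.randn(3, 12, 10)  # (batch, timesteps, features), illustrative sizes

# Explicit feedback: one timestep per call, feeding the state back in
outputs = []
state = None  # None means "start from the zero state"
for t in range(x.size(1)):
    step_out, state = lstm(x[:, t:t + 1, :], state)  # sequence of length 1
    outputs.append(step_out)
out_loop = torch.cat(outputs, dim=1)  # (3, 12, 16)

# A single call performs the same recurrence internally
out_full, _ = lstm(x)
print(torch.allclose(out_loop, out_full, atol=1e-5))  # True
```

The two agree up to floating-point rounding, which suggests that one call plus reading only `r_out[:, -1, :]` already gives a many-to-one classifier.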
If I were to change my code to feed the hidden state back explicitly, I would run into a dilemma regarding batches. `forward(self, x)` receives a batch of samples `x`, so how do I pass the hidden state between consecutive frames of each sequence within the batch? Doing a for-loop inside `forward()` seems like a very bad idea.
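On the batch question: `nn.LSTM` keeps a separate `(h, c)` row for each sample (state shape `(num_layers, batch, hidden)`) and carries it along the time dimension internally, so samples in a batch do not share state. A quick check with illustrative sizes compares one batched call against running each clip on its own:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=10, hidden_size=16, batch_first=True)
x = torch.randn(5, 12, 10)  # 5 independent clips of 12 frames (illustrative sizes)

# One batched call: each row of the state belongs to one clip
out_batch, (h_batch, _) = lstm(x)

# Run each clip alone, each starting from its own zero state
out_single = torch.cat([lstm(x[i:i + 1])[0] for i in range(x.size(0))], dim=0)

print(h_batch.shape)  # torch.Size([1, 5, 16]): one hidden state per sample
print(torch.allclose(out_batch, out_single, atol=1e-5))  # True
```

The per-clip loop here is only for verification; in `forward()` the single batched call already handles every sequence independently.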
I hope someone can shed more insight about this.