Hi, I would like to take the output of a trained CNN model (for example, the output of the layer before the softmax), replace the remaining layers with an LSTM, and train only the added LSTM part.
I have read a couple of threads on this topic (such as "How can l load my best model as a feature extractor/evaluator?") but unfortunately still struggle to fully understand it.
First, I load the trained CNN model:
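import torch
from braindecode.models import ShallowFBCSPNet

# Rebuild the architecture with the same hyperparameters used for training
model = ShallowFBCSPNet(
    in_chans=21,
    n_classes=2,
    input_window_samples=1000,
    final_conv_length=25,
)
# Load the trained weights, then switch to eval mode (no dropout, BatchNorm uses running stats)
model.load_state_dict(torch.load('SavedModels/cnnshallow.pth'))
model.eval()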
which returns the model, whose structure is:
ShallowFBCSPNet(
  (ensuredims): Ensure4d()
  (dimshuffle): Expression(expression=transpose_time_to_spat)
  (conv_time): Conv2d(1, 40, kernel_size=(25, 1), stride=(1, 1))
  (conv_spat): Conv2d(40, 40, kernel_size=(1, 21), stride=(1, 1), bias=False)
  (bnorm): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_nonlin_exp): Expression(expression=square)
  (pool): AvgPool2d(kernel_size=(75, 1), stride=(15, 1), padding=0)
  (pool_nonlin_exp): Expression(expression=safe_log)
  (drop): Dropout(p=0.5, inplace=False)
  (conv_classifier): Conv2d(40, 2, kernel_size=(25, 1), stride=(1, 1))
  (softmax): LogSoftmax(dim=1)
  (squeeze): Expression(expression=squeeze_final_output)
)
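To decide how to reshape the hooked activation, it may help to work out its shape from the layer sizes printed above. A minimal sketch, assuming stride-1 unpadded convolutions and the pooling settings shown (the helper name `out_len` is hypothetical; it is the standard PyTorch output-length formula with dilation 1):

```python
def out_len(length, kernel, stride=1, padding=0):
    """Output length along one axis of a Conv2d/AvgPool2d (dilation 1)."""
    return (length + 2 * padding - kernel) // stride + 1

n_times = 1000                  # input_window_samples
t = out_len(n_times, 25)        # conv_time, kernel 25 along time -> 976
t = out_len(t, 75, stride=15)   # pool, kernel 75, stride 15 -> 61
t = out_len(t, 25)              # conv_classifier, kernel 25 -> 37

# conv_spat collapses the 21 electrodes to width 1, so the raw
# conv_classifier output has shape (batch, n_classes=2, t, 1)
print((1, 2, t, 1))  # (1, 2, 37, 1)
```

So each 1000-sample window would give a 37-step sequence with 2 values per step, which does not match the `input_size=6000` in the LSTM below. Hooking an earlier layer such as `pool_nonlin_exp` would instead give 61 steps with 40 values each.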
I then registered a forward hook on conv_classifier, the layer just before the softmax, to capture its output:
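activation = {}

def get_activation(name):
    # Returns a hook that stores the layer's output (detached from the graph) under `name`
    def hook(module, input, output):
        activation[name] = output.detach()
    return hook

model.conv_classifier.register_forward_hook(get_activation('conv_classifier'))

# Dummy input shaped (batch, in_chans, input_window_samples) to match the model
x = torch.randn(1, 21, 1000)
output = model(x)
print(activation['conv_classifier'])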
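`torch.nn.LSTM` with `batch_first=True` expects input of shape `(batch, seq_len, features)`, so the 4-D hooked activation needs its trailing width-1 axis dropped and its time axis moved to position 1. A minimal sketch with a random tensor standing in for `activation['conv_classifier']`, assuming the `(batch, 2, 37, 1)` shape that the layer sizes above imply for 1000-sample windows:

```python
import torch

# Stand-in for activation['conv_classifier']: (batch, features, time, 1)
act = torch.randn(8, 2, 37, 1)

# (batch, features, time, 1) -> (batch, features, time) -> (batch, time, features)
seq = act.squeeze(-1).permute(0, 2, 1).contiguous()
print(seq.shape)  # torch.Size([8, 37, 2])

# input_size must equal the last dimension of seq (2 here)
lstm = torch.nn.LSTM(input_size=seq.size(-1), hidden_size=21, batch_first=True)
out, (h_n, c_n) = lstm(seq)
print(out.shape, h_n.shape)  # torch.Size([8, 37, 21]) torch.Size([1, 8, 21])
```

A plain `view` to `(batch, -1)` would instead flatten everything into one long vector per window, which loses the time axis the LSTM is meant to iterate over.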
I would like to freeze the CNN weights, reshape the captured activations into a sequence, and put the following LSTM model on top:
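class LSTMModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # input_size must equal the number of features per time step
        # after reshaping the CNN activations to (batch, seq_len, features)
        self.lstm = torch.nn.LSTM(input_size=6000, hidden_size=21, num_layers=1, batch_first=True)
        self.fc = torch.nn.Linear(21, 1)

    def forward(self, x):
        batches = x.size(0)
        # Zero initial states on the same device/dtype as the input
        h0 = x.new_zeros((1, batches, 21))
        c0 = x.new_zeros((1, batches, 21))
        x, _ = self.lstm(x, (h0, c0))
        x = torch.nn.functional.relu(x)
        # One logit per sequence, taken from the last time step
        x = self.fc(x[:, -1, :])
        return x

# Separate name so the CNN stored in `model` is not overwritten
lstm_model = LSTMModel()
# The model outputs raw logits, so BCEWithLogitsLoss (not BCELoss) is needed
criterion = torch.nn.BCEWithLogitsLoss()
# Only the LSTM parameters are optimized
optimizer = torch.optim.Adam(lstm_model.parameters())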
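Putting the pieces together, here is a minimal end-to-end sketch of one way to train only an LSTM on top of frozen, hooked CNN features. To keep it self-contained it uses a small hypothetical stand-in CNN (not ShallowFBCSPNet; its last layer plays the role of `conv_classifier`) and random data, and the class name `LSTMHead` is also hypothetical. With the real model, the freezing and hook steps would apply to `model` and `model.conv_classifier`, and `n_features` would be the channel dimension of the hooked activation (2 for `conv_classifier`).

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in CNN: input (batch, 1, time, channels);
# the last layer outputs (batch, 2, time', 1) like conv_classifier
cnn = torch.nn.Sequential(
    torch.nn.Conv2d(1, 8, kernel_size=(5, 1)),
    torch.nn.Conv2d(8, 8, kernel_size=(1, 21), bias=False),
    torch.nn.BatchNorm2d(8),
    torch.nn.AvgPool2d(kernel_size=(10, 1), stride=(5, 1)),
    torch.nn.Conv2d(8, 2, kernel_size=(3, 1)),
)

# Freeze: no gradients for CNN weights, eval mode for BatchNorm/Dropout
for p in cnn.parameters():
    p.requires_grad_(False)
cnn.eval()

# Same hook pattern as in the question
activation = {}

def get_activation(name):
    def hook(module, input, output):
        activation[name] = output.detach()
    return hook

cnn[-1].register_forward_hook(get_activation("classifier"))


class LSTMHead(torch.nn.Module):
    def __init__(self, n_features, hidden_size=21):
        super().__init__()
        self.lstm = torch.nn.LSTM(n_features, hidden_size, batch_first=True)
        self.fc = torch.nn.Linear(hidden_size, 1)

    def forward(self, x):
        # x: (batch, time, features); initial states default to zeros
        out, _ = self.lstm(x)
        # One logit per sequence from the last time step
        return self.fc(out[:, -1, :]).squeeze(-1)


head = LSTMHead(n_features=2)
criterion = torch.nn.BCEWithLogitsLoss()         # sigmoid is applied inside the loss
optimizer = torch.optim.Adam(head.parameters())  # only the LSTM head is trained

x = torch.randn(4, 1, 100, 21)          # random batch: (batch, 1, time, channels)
y = torch.tensor([0.0, 1.0, 1.0, 0.0])  # binary targets as floats

with torch.no_grad():
    cnn(x)  # the hook stores the classifier-layer output
# (batch, 2, time', 1) -> (batch, time', 2)
seq = activation["classifier"].squeeze(-1).permute(0, 2, 1)

losses = []
for step in range(5):
    optimizer.zero_grad()
    loss = criterion(head(seq), y)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(seq.shape)  # torch.Size([4, 16, 2])
```

Because the CNN is frozen, its forward pass can run under `torch.no_grad()`; in a real training loop it would run once per new batch inside the data loop. If the CNN and LSTM are later wrapped in one parent module, note that calling `.train()` on the parent also switches the CNN back to training mode, so `cnn.eval()` would need to be called again.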
Any help would be greatly appreciated!