Extract layer output from CNN and use it as input for LSTM

Hi, I would like to take the output of an intermediate layer of a CNN model (for example, the layer before the softmax), replace the remaining layers with an LSTM, and train only the added LSTM part.
I have read a couple of articles on this topic (such as "How can l load my best model as a feature extractor/evaluator?"), but unfortunately I still struggle to fully understand it.

The CNN model is loaded and printed below:

model = ShallowFBCSPNet(
    in_chans = 21,
    n_classes = 2,


  (ensuredims): Ensure4d()
  (dimshuffle): Expression(expression=transpose_time_to_spat) 
  (conv_time): Conv2d(1, 40, kernel_size=(25, 1), stride=(1, 1))
  (conv_spat): Conv2d(40, 40, kernel_size=(1, 21), stride=(1, 1), bias=False)
  (bnorm): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_nonlin_exp): Expression(expression=square) 
  (pool): AvgPool2d(kernel_size=(75, 1), stride=(15, 1), padding=0)
  (pool_nonlin_exp): Expression(expression=safe_log) 
  (drop): Dropout(p=0.5, inplace=False)
  (conv_classifier): Conv2d(40, 2, kernel_size=(25, 1), stride=(1, 1))
  (softmax): LogSoftmax(dim=1)
  (squeeze): Expression(expression=squeeze_final_output) 

I registered a forward hook

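import torch

activation = {}
def get_activation(name):
    def hook(module, inputs, output):
        activation[name] = output.detach()
    return hook

# Register the hook on the layer before softmax
model.conv_classifier.register_forward_hook(get_activation('conv_classifier'))

# Assumed input shape: (batch, in_chans, n_times); 1000 samples is a placeholder
x = torch.randn(1, 21, 1000)
output = model(x)
features = activation['conv_classifier']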
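
A minimal sketch of how the hooked activation could be captured from a frozen CNN and reshaped into a sequence for an LSTM. Everything here is an assumption for illustration: `cnn` is a hypothetical stand-in that only mirrors the layer names and kernel sizes printed above (it is not the braindecode model), the hook is placed on `pool` rather than on the layer before softmax because `pool` still carries 40 feature maps over time, and the input length of 1000 samples is a placeholder.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

# Hypothetical stand-in for the CNN above (same layer names and kernel sizes).
cnn = nn.Sequential(OrderedDict([
    ("conv_time", nn.Conv2d(1, 40, kernel_size=(25, 1))),
    ("conv_spat", nn.Conv2d(40, 40, kernel_size=(1, 21), bias=False)),
    ("bnorm", nn.BatchNorm2d(40)),
    ("pool", nn.AvgPool2d(kernel_size=(75, 1), stride=(15, 1))),
    ("conv_classifier", nn.Conv2d(40, 2, kernel_size=(25, 1))),
]))

# Freeze the CNN: no gradients, and eval mode so BatchNorm uses running stats.
for p in cnn.parameters():
    p.requires_grad_(False)
cnn.eval()

activation = {}

def get_activation(name):
    def hook(module, inputs, output):
        activation[name] = output.detach()
    return hook

# The hook only fires once it is registered on a concrete layer.
handle = cnn.pool.register_forward_hook(get_activation("pool"))

# Stand-in input: (batch, 1, n_times, in_chans); 1000 samples is a placeholder.
x = torch.randn(4, 1, 1000, 21)
with torch.no_grad():
    cnn(x)
handle.remove()

features = activation["pool"]                # (batch, 40, time, 1)
seq = features.squeeze(-1).permute(0, 2, 1)  # (batch, time, 40)

# With batch_first=True the LSTM expects (batch, seq_len, input_size),
# so input_size has to match the 40 feature maps.
lstm = nn.LSTM(input_size=40, hidden_size=21, batch_first=True)
out, _ = lstm(seq)
print(features.shape, seq.shape, out.shape)
```

Under these assumptions the time axis of the pooled feature map becomes the LSTM sequence axis and the 40 feature maps become the per-step input, so `input_size` would be 40 rather than 6000.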

I would like to freeze the CNN weights, reshape the tensors and put the following LSTM model on top.

class LSTMModel(torch.nn.Module):
  def __init__(self):
    super().__init__()
    # input_size must equal the feature dimension of each time step fed to the LSTM
    self.lstm = torch.nn.LSTM(input_size=6000, hidden_size=21, num_layers=1, batch_first=True)
    self.fc = torch.nn.Linear(21, 1)

  def forward(self, x):
    # x: (batch, seq_len, input_size); the initial states default to zeros
    x, _ = self.lstm(x)
    x = torch.nn.functional.relu(x)
    x = self.fc(x)  # (batch, seq_len, 1), raw logits
    return x

# Separate name so the CNN stored in `model` is not overwritten
lstm_model = LSTMModel()
# The model returns raw logits (no sigmoid), so BCEWithLogitsLoss replaces BCELoss
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(lstm_model.parameters())
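
A minimal sketch of a single training step for an LSTM head of this kind, assuming the CNN features have already been extracted and reshaped to `(batch, seq_len, features)`. The names `LSTMHead`, `features`, and `labels` are hypothetical, the data is random placeholder data, and reading out only the last time step is one possible design choice, not something the code above does.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class LSTMHead(nn.Module):
    """Hypothetical LSTM head that emits one logit per sequence."""

    def __init__(self, input_size=40, hidden_size=21):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        out, _ = self.lstm(x)          # (batch, seq_len, hidden_size)
        return self.fc(out[:, -1, :])  # last time step only: (batch, 1)

# Placeholder data standing in for frozen-CNN features and binary labels.
features = torch.randn(8, 61, 40)
labels = torch.randint(0, 2, (8, 1)).float()

head = LSTMHead()
# BCEWithLogitsLoss applies the sigmoid internally, so the head returns raw logits.
criterion = nn.BCEWithLogitsLoss()
# Only the head's parameters are handed to the optimizer; the CNN stays untouched.
optimizer = torch.optim.Adam(head.parameters())

optimizer.zero_grad()
logits = head(features)
loss = criterion(logits, labels)
loss.backward()
optimizer.step()
print(logits.shape, loss.item())
```

Because the features are detached from the CNN, gradients flow only through the LSTM and the linear layer, which matches the goal of training just the added part.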

Any help would be greatly appreciated!