LSTM always predicts the same values for all inputs in the batch

Hi everyone, I want to apply an LSTM to a regression problem in which it needs to predict two values for each pixel. Somehow the LSTM model keeps outputting the same values for all inputs in the batch:

```
('out: ', tensor([[0.2576, 0.0891],
        [0.2576, 0.0891],
        [0.2576, 0.0891],
        [0.2576, 0.0891],
        [0.2576, 0.0891],
        [0.2576, 0.0891],
        [0.2576, 0.0891],
        [0.2576, 0.0891],
        [0.2576, 0.0891],
        [0.2576, 0.0891]], device='cuda:0', grad_fn=))
('Tmaps: ', tensor([[0.2760, 0.2000],
        [0.5120, 0.2000],
        [0.8000, 1.0000],
        [0.0700, 0.7000],
        [0.2580, 0.0400],
        [0.0120, 0.0200],
        [0.0160, 0.1600],
        [0.0200, 0.2000],
        [1.0000, 1.0000],
        [0.1200, 0.2800]], device='cuda:0'))
```
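A quick way to confirm that the outputs have truly collapsed (rather than merely being close) is to compare the spread of the predictions across the batch with the spread of the targets. This is a minimal sketch using the ten rows printed above, copied to CPU; `batch_spread` is a hypothetical helper name. If the predictions have near-zero spread while the targets vary, the network is effectively ignoring its input. Comparing the constant prediction with the target mean (ideally over the whole training set, not just this one batch) then shows whether it has settled on predicting an average.

```python
import torch

# Values copied from the printout above (moved to CPU for this sketch).
out = torch.tensor([[0.2576, 0.0891]] * 10)
tmaps = torch.tensor([
    [0.2760, 0.2000],
    [0.5120, 0.2000],
    [0.8000, 1.0000],
    [0.0700, 0.7000],
    [0.2580, 0.0400],
    [0.0120, 0.0200],
    [0.0160, 0.1600],
    [0.0200, 0.2000],
    [1.0000, 1.0000],
    [0.1200, 0.2800],
])


def batch_spread(t):
    """Hypothetical helper: per-output standard deviation across the batch."""
    return t.std(dim=0)


pred_std = batch_spread(out)     # ~0 in every column means collapsed outputs
target_std = batch_spread(tmaps)
target_mean = tmaps.mean(dim=0)  # mean of this batch only

print("prediction std per output:", pred_std)
print("target std per output:    ", target_std)
print("target mean per output:   ", target_mean)
```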

This is how the model is implemented:

```python
import torch.nn as nn


class MRF_RNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, n_layers, batch_size):
        super(MRF_RNN, self).__init__()

        self.batch_size = batch_size
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim

        # batch_first=True: inputs are shaped (batch, seq_len, input_dim)
        self.rnn = nn.LSTM(input_dim, hidden_dim, n_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

        # Orthogonal init for LSTM weight matrices, normal (mean 0, std 1) for biases
        for m in self.modules():
            if isinstance(m, nn.LSTM):
                for param in m.parameters():
                    if len(param.shape) >= 2:
                        nn.init.orthogonal_(param.data)
                    else:
                        nn.init.normal_(param.data)

    def forward(self, x):
        out, (hidden, cell) = self.rnn(x)
        # Feed the output at the last time step of each sequence to the linear layer
        out = self.fc(out[:, -1, :])
        return out
```
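Because `forward` uses `out[:, -1, :]`, it is worth confirming that this slice really is the last time step of each sequence. This is a minimal sketch with hypothetical sizes (the post does not give them) and PyTorch's default initialization rather than the orthogonal init above. For a unidirectional `nn.LSTM` with `batch_first=True`, `out[:, -1, :]` should equal `hidden[-1]`, the top layer's final hidden state. Note that `hidden` and `cell` are not batch-first even when `batch_first=True` is set.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes; the original post does not state them.
batch_size, seq_len, input_dim, hidden_dim, n_layers = 10, 50, 3, 32, 2

lstm = nn.LSTM(input_dim, hidden_dim, n_layers, batch_first=True)
x = torch.randn(batch_size, seq_len, input_dim)  # (batch, seq_len, feature)

out, (hidden, cell) = lstm(x)
print(out.shape)     # (batch, seq_len, hidden_dim)
print(hidden.shape)  # (n_layers, batch, hidden_dim): not batch-first

last_step = out[:, -1, :]  # last time step for every sequence in the batch
top_layer = hidden[-1]     # final hidden state of the top layer

# For a unidirectional LSTM these hold the same values.
print(torch.allclose(last_step, top_layer))

# A slice a batch-first model should NOT use: out[-1] selects the last
# *sample* in the batch, giving (seq_len, hidden_dim) instead of one
# vector per sample.
wrong = out[-1]
print(wrong.shape)

# Spread across the batch of the untrained layer's last-step outputs.
spread = last_step.std(dim=0)
print(spread.mean())
```

If the slice checks out and an untrained layer already produces different outputs per sample, that suggests the collapse develops during training or from how the inputs are prepared, rather than from the indexing itself.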

Any input would be greatly appreciated!

Hello Yilin_Liu,

Did you by any chance figure out what went wrong in your model?

I am facing the exact same problem with an identical model class and get the same results regardless of the input.

I would greatly appreciate your help!

Did either of you get to the bottom of this? I have the same issue: whatever the input, the LSTM returns the same output for each row of the batch. The only difference in my case is that the model trains correctly; the problem appears only when the model is reloaded and used for predictions. I am running the LSTM with batch_first=True.
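Since this symptom appears only after reloading, one thing worth checking is that the reload path reproduces the in-memory model exactly. This is a minimal sketch, assuming the weights are saved as a `state_dict`; `LSTMRegressor`, the layer sizes, and the normalization constants are hypothetical, and an in-memory `BytesIO` buffer stands in for a file on disk. If `before` and `after` match but predictions on real data are still constant, the next suspect is the input preprocessing applied at prediction time.

```python
import io

import torch
import torch.nn as nn


class LSTMRegressor(nn.Module):
    """Hypothetical stand-in with the same structure as MRF_RNN above."""

    def __init__(self, input_dim, hidden_dim, output_dim, n_layers):
        super().__init__()
        self.rnn = nn.LSTM(input_dim, hidden_dim, n_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        out, _ = self.rnn(x)
        return self.fc(out[:, -1, :])


torch.manual_seed(0)
dims = dict(input_dim=3, hidden_dim=32, output_dim=2, n_layers=2)  # hypothetical
model = LSTMRegressor(**dims)

# Save only the learned parameters (the state_dict), not the whole object.
buffer = io.BytesIO()  # stands in for a path such as "model.pt"
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# Rebuild with the same hyperparameters, then load the weights.
restored = LSTMRegressor(**dims)
restored.load_state_dict(torch.load(buffer))
restored.eval()  # matters once dropout or batch-norm layers are involved
model.eval()

# Hypothetical normalization statistics saved from training; the same values
# must be applied to the inputs at prediction time.
train_mean, train_std = 0.5, 2.0
raw = torch.randn(10, 50, dims["input_dim"]) * 2.0 + 0.5
x = (raw - train_mean) / train_std

with torch.no_grad():
    before = model(x)
    after = restored(x)

print(torch.allclose(before, after))
print(after.std(dim=0))  # near zero in every column would mean collapsed outputs
```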

Have any of you figured out this issue? I am having a similar problem and made a post describing my particular case in more detail: "Outputs of LSTM Model all Very Similar or the Same".