I am trying to implement the following two toy snippets:
```python
import torch
import torch.nn as nn

# Implementation 1: a single nn.LSTM over the whole sequence
lstm = nn.LSTM(512, 512, 1, True, True, 0, False)  # batch_first=True, unidirectional
input = torch.rand(128, 260, 512)  # (batch, seq_len, input_size)
output, hidden = lstm(input)
```
```python
# Implementation 2: an nn.LSTMCell unrolled step by step
lstm_cell = nn.LSTMCell(512, 512, True)
input = torch.rand(128, 260, 512).unbind(1)  # tuple of 260 tensors, each (128, 512)
h = c = torch.zeros(128, 512)
for input_step in input:
    h, c = lstm_cell(input_step, (h, c))
```
When the sequence length is set to, for example, 260, the second implementation consumes far more GPU memory than the first one.
Does anyone know the reason? How can I optimize the second implementation? I can't simply switch to nn.LSTM, because I need to add some conditions inside the for loop.
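For concreteness, here is a minimal sketch of the kind of per-step logic I have in mind; the condition itself (resetting the hidden state at a particular step) is just a hypothetical placeholder, not my real logic:

```python
import torch
import torch.nn as nn

lstm_cell = nn.LSTMCell(512, 512, True)
inputs = torch.rand(128, 260, 512).unbind(1)  # one tensor per time step
h = c = torch.zeros(128, 512)
outputs = []
for t, input_step in enumerate(inputs):
    h, c = lstm_cell(input_step, (h, c))
    if t == 100:                  # placeholder condition, hypothetical
        h = torch.zeros_like(h)   # e.g. reset the hidden state mid-sequence
    outputs.append(h)
output = torch.stack(outputs, dim=1)  # (128, 260, 512), same shape as nn.LSTM's output
```

This per-step branching is exactly what nn.LSTM's fused forward pass doesn't let me express, which is why I'm stuck with the cell-based loop.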
I also found a similar post from a while back, but nobody gave a good solution there.