Why and how to flatten LSTM parameters?

Hi, I have only one use of LSTM in my code:

import torch
import torch.nn as nn
import torch.nn.functional as F


class DecoderRNN(nn.Module):
    def __init__(self,
                 embed_size,
                 hidden_size,
                 output_size,
                 dropout_rate,
                 num_layers):

        super(DecoderRNN, self).__init__()

        self.hidden_size = hidden_size
        self.embed_size = embed_size
        self.output_size = output_size
        self.dropout_rate = dropout_rate
        self.num_layers = num_layers

        self.embedding = nn.Embedding(self.output_size, self.embed_size)

        # hard-coded LSTM input size: embedding plus concatenated context
        self.LSTM = nn.LSTM(1112, self.hidden_size, num_layers=self.num_layers,
                            dropout=self.dropout_rate, batch_first=True)
        self.dropout = nn.Dropout(self.dropout_rate)
        self.out = nn.Linear(self.hidden_size, self.output_size)

        self.visual_bn = nn.BatchNorm1d(self.embed_size, momentum=0.01)
        self.visual_linear = nn.Linear(self.embed_size, self.hidden_size)

        self.output_linear = nn.Linear(self.hidden_size, self.hidden_size)

    def forward(self, image_encodings, encoder_outputs, hidden, output,
                prev_context, flag=0, output_lengths=None):

        saved_image_encodings = image_encodings
        image_encodings = self.visual_bn(image_encodings)
        image_encodings = image_encodings.unsqueeze(1)

        prev_context = torch.cat((image_encodings, prev_context), 2)

        embedded = self.embedding(output)
        rnn_input = embedded

        new_input = torch.cat((rnn_input, prev_context), dim=2)

        # project a transformer-sized hidden state (dim 300) into the LSTM format
        # (transformer_to_lstm_layer is defined outside this snippet)
        if hidden[0].shape[2] == 300:
            hidden = (self.transformer_to_lstm_layer(hidden[0]).permute((1, 0, 2)),
                      self.transformer_to_lstm_layer(hidden[1]).permute((1, 0, 2)))

        output, hidden = self.LSTM(new_input, hidden)
        #self.LSTM.flatten_parameters()

        # keep only the hidden state of the last LSTM layer
        if self.num_layers > 1:
            cur_hidden = hidden[0][-1:]
        else:
            cur_hidden = hidden[0]

        output_final = F.tanh(self.output_linear(output))
        output_final = self.dropout(output_final)
        output = self.out(output_final)

        return output, hidden, cur_hidden.permute((1, 0, 2))

I still get the warnings below even when I uncomment self.LSTM.flatten_parameters(), and the code is just too slow:

Where should I call flatten_parameters() here?

/opt/conda/conda-bld/pytorch_1565272279342/work/aten/src/ATen/native/cudnn/RNN.cpp:1266: UserWarning: RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. To compact weights again call flatten_parameters().

(the same warning is repeated many times per run)

Any hints on fixing the issue? Also, why do only LSTM parameters need to be flattened? Why not linear layers?


This is linked to the post here: Bug in Data Parallel?

Put self.LSTM.flatten_parameters() before the LSTM call.
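For example, here is a minimal sketch of where that call would go, using a stripped-down module with made-up sizes rather than the full DecoderRNN above:

import torch
import torch.nn as nn

class TinyDecoder(nn.Module):
    """Minimal stand-in for the decoder, just to show where the call goes."""
    def __init__(self, input_size=1112, hidden_size=512, num_layers=2):
        super().__init__()
        self.LSTM = nn.LSTM(input_size, hidden_size,
                            num_layers=num_layers, batch_first=True)

    def forward(self, new_input, hidden):
        # re-compact the weights into one contiguous chunk for cuDNN
        # right before the LSTM is invoked
        self.LSTM.flatten_parameters()
        output, hidden = self.LSTM(new_input, hidden)
        return output, hidden

# dummy usage: batch of 4 sequences of length 10
dec = TinyDecoder()
x = torch.randn(4, 10, 1112)
h0 = torch.zeros(2, 4, 512)
c0 = torch.zeros(2, 4, 512)
out, (hn, cn) = dec(x, (h0, c0))

As far as I understand, the call has to go inside forward rather than __init__ because wrappers like nn.DataParallel replicate the module onto each GPU on every forward pass, and the replicated weight tensors are no longer contiguous; calling flatten_parameters() inside forward re-compacts them on each replica.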


@Rafael_R Did you solve the problem? I still get the warning.

I am having the same problem; however, I have already trained my model with this warning and saved it.
During inference (testing) I get this warning plus an OOM error. I wonder if I can fix this when using an already trained model?

Thanks