How to flatten the parameters of LSTM layers with spectral-normed weights

It looks like calling layer.flatten_parameters() directly has no effect on LSTM layers whose weights are wrapped with spectral norm.
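For what it's worth, here is a minimal sketch that reproduces the warning for me (toy sizes; it needs a CUDA device, since the warning comes from cuDNN). My understanding is that spectral_norm registers a forward pre-hook that re-assigns weight_hh_l0 from weight_hh_l0_orig on every call, so the weight never stays inside the flat cuDNN buffer:

import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

lstm = nn.LSTM(16, 32, 1, batch_first=True, bidirectional=True)
lstm = spectral_norm(lstm, 'weight_hh_l0')

if torch.cuda.is_available():
    lstm = lstm.cuda()
    lstm.flatten_parameters()  # seems to have no lasting effect here
    x = torch.randn(4, 10, 16, device='cuda')
    out, _ = lstm(x)  # emits the UserWarning quoted below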

I also tried removing the spectral norm first and then flattening, but that didn't work either.

from torch.nn.utils import spectral_norm, remove_spectral_norm

self.bilstm = nn.LSTM(in_dim, lstm_channels, 1,
                      batch_first=True, bidirectional=True)

# wrap the recurrent weights of both directions with spectral norm
self.bilstm = spectral_norm(self.bilstm, 'weight_hh_l0')
self.bilstm = spectral_norm(self.bilstm, 'weight_hh_l0_reverse')

...

# later: remove spectral norm from both directions, then try to re-flatten
self.bilstm = remove_spectral_norm(self.bilstm, name="weight_hh_l0")
self.bilstm = remove_spectral_norm(self.bilstm, name="weight_hh_l0_reverse")
self.bilstm.flatten_parameters()
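A small diagnostic that could verify whether the removal actually took effect (note that _flat_weights and _flat_weights_names are internal PyTorch attributes and may change between versions):

# every entry should be a plain Parameter again after remove_spectral_norm();
# tensors living in separate storages here would explain the warning
for name, w in zip(self.bilstm._flat_weights_names, self.bilstm._flat_weights):
    print(name, type(w).__name__, w.is_contiguous(), w.data_ptr())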

...

# flatten again right before the forward call
self.bilstm.flatten_parameters()
context = self.bilstm(context)[0]

Despite all of this, I still receive the warning:

UserWarning: RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. To compact weights again call flatten_parameters(). (Triggered internally at /opt/conda/conda-bld/pytorch_1666642975993/work/aten/src/ATen/native/cudnn/RNN.cpp:968.)
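Given the hook behaviour above, my guess is that flattening can only stick after the spectral norm has been removed, and flatten_parameters() also returns early for CPU weights, so the order of removal, device move, and flattening seems to matter. The ordering I would expect to work is roughly this (untested sketch):

# untested sketch: remove spectral norm first, then flatten once the
# module is already on the GPU; as far as I can tell, moving the module
# (.cuda()/.to(device)) also re-runs the flattening internally
self.bilstm = remove_spectral_norm(self.bilstm, name="weight_hh_l0")
self.bilstm = remove_spectral_norm(self.bilstm, name="weight_hh_l0_reverse")
self.bilstm = self.bilstm.cuda()
self.bilstm.flatten_parameters()

Is that reasoning correct, or is there a supported way to combine spectral norm with flattened cuDNN weights?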