How to use torch.nn.utils.weight_norm with an LSTM

I want to apply weight normalization to an LSTM.

Is it correct to apply it to an nn.LSTMCell as below?

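import torch.nn as nn
from torch.nn.utils import weight_norm as weightNorm

class Sequence(nn.Module):
    def __init__(self, input_dim, hidden_size):
        super(Sequence, self).__init__()
        # Wrap the cell twice so both the input-to-hidden (weight_ih)
        # and hidden-to-hidden (weight_hh) matrices are weight-normalized
        self.lstm = weightNorm(
            weightNorm(nn.LSTMCell(input_dim, hidden_size), name="weight_ih"),
            name="weight_hh")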
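One way to check whether the nested wrapping does what is intended is to inspect the parameters it registers, confirm the recomputed weight matches the weight-norm formula, and run a forward and backward step. This is a minimal sketch, assuming the legacy torch.nn.utils.weight_norm (recent PyTorch releases mark it deprecated in favor of torch.nn.utils.parametrizations.weight_norm, but it is still available); the sizes 10 and 20 and the batch of 3 are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm

input_dim, hidden_size = 10, 20  # arbitrary sizes, for illustration only

# weight_norm returns the same module, so the two calls stack two hooks on one cell
cell = weight_norm(
    weight_norm(nn.LSTMCell(input_dim, hidden_size), name="weight_ih"),
    name="weight_hh",
)

# Each wrapped parameter is replaced by a magnitude `<name>_g` and a
# direction `<name>_v`; the biases are left untouched.
param_names = sorted(name for name, _ in cell.named_parameters())
print(param_names)

# With the default dim=0, each row of weight_ih is g * v / ||v|| (row-wise norm).
v, g = cell.weight_ih_v, cell.weight_ih_g
manual = g * v / v.norm(dim=1, keepdim=True)
print(torch.allclose(manual, cell.weight_ih))

# Forward step: hx defaults to zeros when omitted.
x = torch.randn(3, input_dim)
h, c = cell(x)
print(h.shape, c.shape)

# Gradients should reach the new g/v parameters, not the removed weight_ih.
h.sum().backward()
print(cell.weight_ih_g.grad is not None, cell.weight_hh_v.grad is not None)
```

Because the forward pre-hooks recompute weight_ih and weight_hh before every call, the optimizer only ever updates the `_g` and `_v` tensors.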