I want to apply weight normalization to an LSTM.
Is the following the right way to do it?
import torch.nn.utils.weight_norm as weightNorm
class Sequence(nn.Module):
    """LSTM cell with weight normalization on both of its weight matrices.

    ``torch.nn.utils.weight_norm`` is applied to the cell's input-to-hidden
    (``weight_ih``) and hidden-to-hidden (``weight_hh``) matrices. Each
    application replaces the named parameter with a magnitude parameter
    ``<name>_g`` and a direction parameter ``<name>_v``, and the weight is
    recomputed from them before every forward call of the cell. The biases
    are not normalized.

    Args:
        input_dim: Number of input features fed to the cell. If omitted, the
            module-level global ``input_dim`` is used (the original
            zero-argument behaviour).
        hidden_size: Size of the hidden and cell states. If omitted, the
            module-level global ``hidden_size`` is used.
    """

    def __init__(self, input_dim=None, hidden_size=None):
        super().__init__()
        # Backward compatibility: the constructor used to take no arguments
        # and read these sizes from module-level globals of the same names.
        if input_dim is None:
            input_dim = globals()["input_dim"]
        if hidden_size is None:
            hidden_size = globals()["hidden_size"]

        cell = nn.LSTMCell(input_dim, hidden_size)
        # weight_norm returns the module it modified, so nesting the calls
        # (as before) works; applying them one after another is equivalent
        # and easier to read. The default dim=0 gives every row of the
        # stacked (4 * hidden_size, ...) gate matrices its own magnitude.
        # NOTE: newer PyTorch releases deprecate this hook-based API in favour
        # of torch.nn.utils.parametrizations.weight_norm, which stores the
        # parameters under different state_dict keys; kept here so existing
        # checkpoints still load.
        cell = nn.utils.weight_norm(cell, name="weight_ih")
        cell = nn.utils.weight_norm(cell, name="weight_hh")
        self.lstm = cell