Hello all,
I am basically trying to do mean-variance normalization in embedding space. I have an RNN that embeds a sequence and then calculates the mean and standard deviation, which are used to normalize the input of another sub-network. This model seems to work (at least my training and validation losses behave themselves). I now want to learn weights corresponding to the mean and variance — sort of like batch-norm (in the most hand-wavy way possible).
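For context, by mean-variance normalization I mean the usual operation, sketched here on a toy tensor (the shapes are just for illustration):

```python
import torch

# Toy example: normalize each feature to zero mean, unit variance.
x = torch.randn(4, 840)                  # batch of 840-dim feature vectors
mean = x.mean(dim=0, keepdim=True)       # per-feature mean, shape (1, 840)
std = x.std(dim=0, keepdim=True)         # per-feature std,  shape (1, 840)
x_norm = (x - mean) / (std + 1e-8)       # small eps avoids division by zero
```

In my model, the mean and std come from the RNN embedding rather than from `x` itself, and I now want to pass them through learnable weight matrices first.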
Q. Of course, PyTorch is magical and my model seems to be training. But am I doing it correctly?
Q. Do I need to do anything else with the two new weight matrices I have introduced, or will they automatically be added to the parameter list of the RNN class?
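To illustrate what I mean, here is a minimal toy module (the names and sizes are just for illustration) and a check of what `named_parameters()` reports:

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(10, 20, batch_first=True)
        # Assigning an nn.Parameter as a module attribute should register it.
        self.W_mean = nn.Parameter(torch.randn(5, 5))

names = [n for n, _ in Toy().named_parameters()]
print(names)  # does 'W_mean' appear alongside the LSTM weights?
```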
Q. What would be the best way to check whether these weights are actually learning something? Some sort of check on the gradients with respect to these parameters?
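The kind of check I had in mind, sketched on a standalone parameter (the loss here is a stand-in, not my actual objective):

```python
import torch

# Hedged sketch: after loss.backward(), inspect the gradient norm of the
# parameter, and snapshot the weight before/after an optimizer step.
W = torch.nn.Parameter(torch.randn(840, 840))
x = torch.randn(8, 840)
loss = (x @ W).pow(2).mean()
loss.backward()
print(W.grad.norm())  # nonzero norm => gradient is flowing into W

before = W.detach().clone()
torch.optim.SGD([W], lr=0.1).step()
print((W.detach() - before).abs().max())  # > 0 => the weight actually moved
```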
My code (I have added the imports and the missing `return`; I also replaced the deprecated `Variable(...).cuda()` wrappers with plain tensors created on the input's device):

```python
import torch
import torch.nn as nn

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.input_size = input_size
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        # Project down to the feature dimension
        self.proj = nn.Linear(hidden_size, 840)
        # Weights for the channel mean and standard deviation
        self.W_mean = nn.Parameter(torch.randn(840, 840))
        self.W_std = nn.Parameter(torch.randn(840, 840))

    def forward(self, seq, x):
        # Set initial states on the same device as the input
        h0 = torch.zeros(self.num_layers, seq.size(0), self.hidden_size, device=seq.device)
        c0 = torch.zeros(self.num_layers, seq.size(0), self.hidden_size, device=seq.device)
        out, _ = self.lstm(seq, (h0, c0))
        # Project the LSTM embeddings to the feature size
        out = out.contiguous().view(out.size(0) * out.size(1), self.hidden_size)
        proj = self.proj(out)
        proj = proj.view(-1, seq.size(1), x.size(1))
        # 840-dimensional average embedding over the time dimension
        avg_emb = torch.mean(proj, 1)
        std_emb = torch.std(proj, dim=1)
        avg_emb = avg_emb.view(-1, x.size(1))
        std_emb = std_emb.view(-1, x.size(1))
        # Mean-variance normalization of the speech frames, with learned
        # transforms of the mean and std embeddings
        x_norm = (x - torch.mm(avg_emb, self.W_mean)) / torch.mm(std_emb, self.W_std)
        return x_norm
```
To summarize, I am projecting the LSTM output down to the same dimension as the input of my other network, i.e. 840. This also matches the size of the W_mean and W_std weight matrices, i.e. (840 x 840).
Thanks,
Gautam