I am basically trying to do mean-variance normalization in embedding space. I have an RNN that embeds a sequence and then calculates the mean and standard deviation, which are used to normalize the input of another sub-network. This model seems to work (at least my training and validation losses behave themselves). I now want to try to learn weights corresponding to the mean and variance, sort of like batch norm (in the most hand-wavy way possible).
Q. Of course, PyTorch is magical and my model seems to be training. But am I doing it correctly?
Q. Do I need to do anything else with the two new weight matrices I have introduced, or will they automatically be added to the parameter list of the RNN class?
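For context on this question, a minimal sketch of how registration works: any `nn.Parameter` assigned as an attribute of an `nn.Module` is registered automatically, whereas a plain tensor attribute is not. The `Tiny` module and attribute names here are made up for illustration:

```python
import torch
import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.Parameter attributes are auto-registered by nn.Module
        self.W_mean = nn.Parameter(torch.randn(4, 4))
        # a plain tensor attribute is NOT registered as a parameter
        self.not_a_param = torch.randn(4, 4)

model = Tiny()
names = [n for n, _ in model.named_parameters()]
print(names)  # only 'W_mean' appears
```

Printing `model.named_parameters()` like this is a quick way to confirm that `W_mean` and `W_std` are actually being handed to the optimizer.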
Q. What would be the best way to check whether these weights are actually learning something? Some sort of check on the gradients w.r.t. these parameters?
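One common way to answer this is to run a single backward pass and inspect the `.grad` attribute of the new parameters. This is a standalone sketch with a made-up loss, not the model from the post:

```python
import torch
import torch.nn as nn

# A free-standing parameter and a toy loss, just to show the check
W = nn.Parameter(torch.randn(3, 3))
x = torch.randn(5, 3)

loss = (x @ W).pow(2).mean()
loss.backward()

# If .grad is None, the parameter is disconnected from the graph;
# a non-zero norm means gradient signal is reaching it.
assert W.grad is not None
grad_norm = W.grad.norm().item()
print(f"grad norm for W: {grad_norm:.4f}")
```

Another simple check is to clone the parameter before an optimizer step and verify it changed afterwards with `torch.equal`.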
My code:

```python
import torch
import torch.nn as nn
from torch.autograd import Variable

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.input_size = input_size
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        # project down to feature dimension
        self.proj = nn.Linear(hidden_size, 840)
        # define weights for the channel mean and variance
        self.W_mean = nn.Parameter(torch.randn(840, 840))
        self.W_std = nn.Parameter(torch.randn(840, 840))

    def forward(self, seq, x):
        # set initial states
        h0 = Variable(torch.zeros(self.num_layers, seq.size(0), self.hidden_size).cuda())
        c0 = Variable(torch.zeros(self.num_layers, seq.size(0), self.hidden_size).cuda())
        out, _ = self.lstm(seq, (h0, c0))
        # project LSTM embeddings to feature size
        out = out.contiguous().view(out.size(0) * out.size(1), self.hidden_size)
        proj = self.proj(out)
        proj = proj.view(-1, seq.size(1), x.size(1))
        # 840-dimensional average embedding
        avg_emb = torch.mean(proj, 1)
        std_emb = torch.std(proj, dim=1)
        # subtract the average embedding from the speech frames
        avg_emb = avg_emb.view(-1, x.size(1))
        std_emb = std_emb.view(-1, x.size(1))
        # mean-variance normalization
        x_norm = (x - torch.mm(avg_emb, self.W_mean)) / torch.mm(std_emb, self.W_std)
        return x_norm
```
To summarize, I am projecting the LSTM output down to the same dimension as the input of my other network, i.e. 840. This also corresponds to the size of the W_mean and W_std weight matrices, i.e. (840x840).
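To make the shapes concrete, here is the normalization step in isolation with a reduced feature size (8 here stands in for the 840 in the post; all dimensions are illustrative):

```python
import torch

B, T, D = 2, 5, 8            # batch, sequence length, feature dim (840 in the post)
proj = torch.randn(B, T, D)  # projected LSTM outputs
x = torch.randn(B, D)        # frames to normalize

avg_emb = proj.mean(dim=1)   # (B, D) mean over the time dimension
std_emb = proj.std(dim=1)    # (B, D) std over the time dimension

W_mean = torch.randn(D, D)
W_std = torch.randn(D, D)

# (B, D) @ (D, D) -> (B, D), so x_norm keeps the shape of x
x_norm = (x - avg_emb @ W_mean) / (std_emb @ W_std)
print(x_norm.shape)
```

One caveat worth noting: since `W_std` is unconstrained, `std_emb @ W_std` can take values arbitrarily close to zero (or negative), so in practice it may be worth adding a small epsilon or passing the denominator through something like `F.softplus` to keep it positive.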