LSTM with layer/batch normalization

I’m trying to implement an LSTM with layer normalization, but I get an error when I run loss.backward(). If I remove the LayerNormalization modules that I’ve created, it runs fine. I guess I didn’t set up layer normalization correctly, but I’m still new to PyTorch, so any help would be appreciated!

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-28-9c20ee070d0a> in <module>()
     17 
     18 for epoch in range(1, n_epochs + 1):
---> 19     loss = train(*random_training_set())
     20     loss_avg += loss
     21 

<ipython-input-17-d72503b5d7db> in train(inp, target)
      9         loss += criterion(output, target[c])
     10 
---> 11     loss.backward()
     12     decoder_optimizer.step()
     13 

/usr/local/lib/python2.7/dist-packages/torch/autograd/variable.pyc in backward(self, gradient, retain_variables)
    144                     'or with gradient w.r.t. the variable')
    145             gradient = self.data.new().resize_as_(self.data).fill_(1)
--> 146         self._execution_engine.run_backward((self,), (gradient,), retain_variables)
    147 
    148     def register_hook(self, hook):

/usr/local/lib/python2.7/dist-packages/torch/autograd/_functions/reduce.pyc in backward(self, grad_output)
     95             grad_input_val = grad_output[0]
     96             grad_input_val /= reduce(lambda x, y: x * y, self.input_size, 1)
---> 97             return grad_output.new(*self.input_size).fill_(grad_input_val)
     98         else:
     99             repeats = [1 for _ in self.input_size]

TypeError: fill_ received an invalid combination of arguments - got (torch.cuda.FloatTensor), but expected (float value)

Here is my code:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.parameter import Parameter

class LayerNormalization(nn.Module):
    def __init__(self, hidden_size, eps=1e-5):
        super(LayerNormalization, self).__init__()
        
        self.eps = eps
        self.hidden_size = hidden_size
        self.a2 = nn.Parameter(torch.ones(1, hidden_size), requires_grad=True)
        self.b2 = nn.Parameter(torch.zeros(1, hidden_size), requires_grad=True)
        
    def forward(self, z):
        mu = torch.mean(z)
        sigma = torch.std(z)

        ln_out = (z - mu.expand_as(z)) / (sigma.expand_as(z) + self.eps)
        ln_out = ln_out * self.a2 + self.b2
        return ln_out
    
class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, embed_size, output_size):
        super(LSTM, self).__init__()

        self.hidden_size = hidden_size
        # input embedding
        self.encoder = nn.Embedding(input_size, embed_size)
        # lstm weights
        self.weight_fh = nn.Linear(hidden_size, hidden_size)
        self.weight_ih = nn.Linear(hidden_size, hidden_size)
        self.weight_ch = nn.Linear(hidden_size, hidden_size)
        self.weight_oh = nn.Linear(hidden_size, hidden_size)
        self.weight_fx = nn.Linear(embed_size, hidden_size)
        self.weight_ix = nn.Linear(embed_size, hidden_size)
        self.weight_cx = nn.Linear(embed_size, hidden_size)
        self.weight_ox = nn.Linear(embed_size, hidden_size)
        # decoder
        self.decoder = nn.Linear(hidden_size, output_size)
        # layer normalization
        self.lnx = LayerNormalization(hidden_size)
        self.lnh = LayerNormalization(hidden_size)
        self.lnc = LayerNormalization(hidden_size)

    def forward(self, inp, h_0, c_0):
        # encode the input characters
        inp = self.encoder(inp)
        # forget gate
        f_g = F.sigmoid(self.lnx(self.weight_fx(inp)) + self.lnh(self.weight_fh(h_0)))
        # input gate
        i_g = F.sigmoid(self.lnx(self.weight_ix(inp)) + self.lnh(self.weight_ih(h_0)))
        # intermediate cell state
        c_tilda = F.tanh(self.lnx(self.weight_cx(inp)) + self.lnh(self.weight_ch(h_0)))
        # current cell state
        cx = f_g * c_0 + i_g * c_tilda
        # output gate
        o_g = F.sigmoid(self.lnx(self.weight_ox(inp)) + self.lnh(self.weight_oh(h_0)))
        # hidden state
        hx = o_g * F.tanh(self.lnc(cx))

        out = self.decoder(hx.view(1,-1))

        return out, hx, cx

    def init_hidden(self):
        h_0 = Variable(torch.zeros(1, self.hidden_size)).cuda()
        c_0 = Variable(torch.zeros(1, self.hidden_size)).cuda()
        return h_0, c_0

Can you post a complete script? It’s hard to tell what went wrong just from your snippet.


I just replaced the GRU in this notebook with the code that I posted above. I also had to pass hx and cx instead of just ‘hidden’ and fix the output view in the loss criterion.
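Concretely, the adapted train() step comes out roughly like this (a hypothetical reconstruction: only the lines visible in the traceback above are certain, and decoder, criterion, decoder_optimizer, and chunk_len are the notebook’s names):

def train(inp, target):
    # two recurrent states instead of the single 'hidden' of the GRU version
    hx, cx = decoder.init_hidden()
    decoder.zero_grad()
    loss = 0

    for c in range(chunk_len):
        # step the model one character at a time
        output, hx, cx = decoder(inp[c], hx, cx)
        loss += criterion(output, target[c])

    loss.backward()
    decoder_optimizer.step()

    return loss.data[0] / chunk_len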


I tried your LayerNormalization class. It worked after changing the line:

ln_out = ln_out * self.a2 + self.b2

to:

ln_out = ln_out * self.a2.expand_as(z) + self.b2.expand_as(z)
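For completeness, here is the forward from the class above with that fix applied (only the affine line changes):

def forward(self, z):
    mu = torch.mean(z)
    sigma = torch.std(z)

    ln_out = (z - mu.expand_as(z)) / (sigma.expand_as(z) + self.eps)
    # expand the gain/bias to the input shape; without broadcasting,
    # elementwise ops need matching sizes
    ln_out = ln_out * self.a2.expand_as(z) + self.b2.expand_as(z)
    return ln_out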

If I’m not mistaken, this does the same as nn.InstanceNormalization with affine=True, if you want to go with the stock classes.

Best regards

Thomas


Is nn.InstanceNorm a new addition?

At PyTorch’s pace, a month isn’t all that new. Relative to your original question, it’s not all that old. 🙂 Kudos to the people doing all the work.

Best regards

Thomas

As tom already pointed out, you are doing batch normalization, not layer normalization! Maybe you could change the topic name to reflect that?

I don’t think that I am doing batch normalization. If you look at the supplementary material in the layer normalization paper, my equations match theirs.

https://arxiv.org/abs/1607.06450

Layer normalization uses all the activations of one instance for normalization, while batch normalization uses the whole batch for each activation. Ok, but you didn’t normalize per neuron, so it was a mix of both. So we were both right and wrong. 🙂 (sorry for the confusion)
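To see the difference on a (B, N) activation matrix: batch normalization takes its statistics per neuron over the batch dimension, layer normalization per instance over the feature dimension. A tiny sketch (assuming a recent PyTorch for torch.tensor and keepdim):

import torch

z = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.]])     # (B, N): B=2 instances, N=3 neurons

bn_mu = z.mean(dim=0, keepdim=True)  # per-neuron mean over the batch -> (1, N)
ln_mu = z.mean(dim=1, keepdim=True)  # per-instance mean over neurons -> (B, 1)

print(bn_mu)  # one mean per neuron:   [[2.5, 3.5, 4.5]]
print(ln_mu)  # one mean per instance: [[2.], [5.]]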

If I didn’t miss something, you should use

def forward(self, z):
    mu = torch.mean(z, dim=1)
    sigma = torch.std(z, dim=1)

if your input is (B, N), with B the batch size and N the input size, and you want to do layer normalization.
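For anyone finding this thread later: newer PyTorch releases ship this as nn.LayerNorm, which includes the per-feature affine gain and bias that a2 and b2 implement above. A minimal sketch, assuming a version that has it:

import torch
import torch.nn as nn

hidden_size = 100
ln = nn.LayerNorm(hidden_size)    # per-feature affine gain/bias, like a2/b2
z = torch.randn(8, hidden_size)   # (B, N)
out = ln(z)                       # each row normalized over its N features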
