Fixing the Training Function for an LSTM model built from scratch

So, I’m currently trying to build a custom LSTM model. To do so, I’m building a standard LSTM from scratch first, which I’ll then slowly modify. However, right now I’m having difficulty making the model learn at all.

Here is the code for my LSTM model.

import math
import torch
import torch.nn as nn

class LSTM_Model(nn.Module):
    def __init__(self, input_sz, hidden_sz, out_sz, bs):
        super().__init__()
        self.input_sz = input_sz
        self.hidden_size = hidden_sz
        self.W = nn.Parameter(torch.Tensor(input_sz, hidden_sz * 4))
        self.U = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz * 4))
        self.bias = nn.Parameter(torch.Tensor(hidden_sz * 4))
        self.fc1 = nn.Linear(hidden_sz, out_sz)
        self.init_weights()

    def init_weights(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def forward(self, x, init_states=None):
        bs, seq_sz, _ = x.size()
        hidden_seq = []
        if init_states is None:
            h_t, c_t = (torch.zeros(bs, self.hidden_size).to(x.device),
                        torch.zeros(bs, self.hidden_size).to(x.device))
        else:
            h_t, c_t = init_states

        for t in range(seq_sz):
            x_t = x[:, t, :]
            if x_t.size(0) != bs:
                fill_in = torch.zeros(bs - x_t.size(0), seq_sz)
                x_t = torch.cat((x_t, fill_in), 0)

            # All four gate pre-activations in one matmul, then split into
            # input gate, forget gate, cell candidate, and output gate.
            gates = torch.matmul(x_t, self.W) + torch.matmul(h_t, self.U) + self.bias
            i_t, f_t, g_t, o_t = gates.chunk(4, 1)
            i_t, f_t, g_t, o_t = torch.sigmoid(i_t), torch.sigmoid(f_t), torch.tanh(g_t), torch.sigmoid(o_t)
            c_t = f_t * c_t + i_t * g_t
            h_t = o_t * torch.tanh(c_t)
            hidden_seq.append(h_t)
        out = self.fc1(h_t)  # predict from the final hidden state only
        return out, (h_t, c_t)

Here is my training function.

def train_model(data_loader, model, loss_function, optimizer):
    num_batches = len(data_loader)
    total_loss = 0
    model.train()
    hidden = None
    #torch.autograd.set_detect_anomaly(True)
    for X, y in data_loader:
        output, hidden = model(X, hidden)
        loss = loss_function(output, y)

        optimizer.zero_grad()
        loss.backward()
        #loss.backward(retain_graph=True)
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / num_batches
    print(f"Train loss: {avg_loss}")
    return avg_loss

Running it gives an initial error message that tells me to pass retain_graph=True to backward(), but after I add that, I just get more errors.

I also noticed that if I remove the hidden variable from the training function (so the state is never carried between batches), the program runs and gives an output, but the model does not learn at all, probably because it’s throwing away the hidden state.

Yeah, but beyond that I have no idea what’s going wrong. Please help.

Detach the hidden states in each iteration, or delay the backward call if you want to backpropagate through multiple steps.

Could you explain how to detach the hidden states and what that does? I’ve implemented it here at the beginning of the forward function.

class LSTM_Model(nn.Module):
    def __init__(self, input_sz, hidden_sz, out_sz, bs):
        super().__init__()
        self.input_size = input_sz
        self.hidden_size = hidden_sz
        self.batch_size = bs
        self.W = nn.Parameter(torch.Tensor(input_sz, hidden_sz * 4))
        self.U = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz * 4))
        self.bias = nn.Parameter(torch.Tensor(hidden_sz * 4))
        self.fc1 = nn.Linear(hidden_sz, out_sz)
        self.init_weights()

    def init_weights(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def forward(self, x, init_states=None):
        bs, seq_sz, _ = x.size()
        hidden_seq = []
        if init_states is None:
            h_t, c_t = (torch.zeros(bs, self.hidden_size).to(x.device),
                        torch.zeros(bs, self.hidden_size).to(x.device))
        else:
            h_t, c_t = init_states

        # Cut the incoming states out of the previous batch's computation
        # graph while keeping their values.
        h_t = h_t.detach()
        c_t = c_t.detach()

        for t in range(seq_sz):
            x_t = x[:, t, :]
            if x_t.size(0) < self.batch_size:
                # Pad a smaller final batch up to the expected batch size with zeros.
                fill_in = torch.zeros(self.batch_size - x_t.size(0), self.input_size, device=x.device)
                x_t = torch.cat((x_t, fill_in), 0)
            gates = torch.matmul(x_t, self.W) + torch.matmul(h_t, self.U) + self.bias
            i_t, f_t, g_t, o_t = gates.chunk(4, 1)
            i_t, f_t, g_t, o_t = torch.sigmoid(i_t), torch.sigmoid(f_t), torch.tanh(g_t), torch.sigmoid(o_t)
            c_t = f_t * c_t + i_t * g_t
            h_t = o_t * torch.tanh(c_t)
            hidden_seq.append(h_t)
        out = self.fc1(h_t)
        return out, (h_t, c_t)

However, this does not change the performance of the model at all, and it still can’t learn anything.

Using the LSTM implementation in PyTorch (just the regular nn.LSTM) leads to errors as low as 0.0002 or something like that, but here the model is just refusing to learn.
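For reference, a minimal sketch of that kind of built-in baseline (the exact setup isn’t shown in this thread; this assumes a single-layer nn.LSTM feeding the same linear head as the custom model above):

import torch.nn as nn

class BuiltinLSTM(nn.Module):
    # Assumed baseline, not the exact code from this thread: a single-layer
    # built-in nn.LSTM followed by the same linear head as the custom model.
    def __init__(self, input_sz, hidden_sz, out_sz):
        super().__init__()
        self.lstm = nn.LSTM(input_sz, hidden_sz, batch_first=True)
        self.fc1 = nn.Linear(hidden_sz, out_sz)

    def forward(self, x, init_states=None):
        # nn.LSTM returns the full hidden sequence plus the final (h_n, c_n).
        seq_out, (h_t, c_t) = self.lstm(x, init_states)
        # Predict from the last layer's final hidden state, mirroring
        # out = self.fc1(h_t) in the custom model.
        return self.fc1(h_t[-1]), (h_t, c_t)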

Since the states are still attached to the computation graph, calling backward in the second iteration will try to backpropagate through them into the first iteration, which no longer has the intermediate activations needed for the gradient computation. Detaching the states is a common approach: it keeps the actual values but cuts the computation graph, making sure the backward call only computes the gradients for the current iteration.
I don’t know why your model is performing worse, but I would assume you have detached the states in the PyTorch implementation, too?
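A minimal sketch of what detaching does at the tensor level, with toy values standing in for the hidden states:

import torch

w = torch.tensor(2.0, requires_grad=True)

# "Iteration 1": build a graph from w to a loss and backprop through it.
h = w * 3.0
(h * 4.0).backward()  # w.grad is now 12.0; the graph's buffers are freed

# Reusing h as-is in "iteration 2" would backprop through the freed graph
# again (the error that suggests retain_graph=True). Detaching keeps h's
# value but gives it no history, so backward stops at this iteration.
h = h.detach()
(h * w).backward()
print(w.grad)  # tensor(18.) = 12.0 from before + 6.0 from this call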

I’ve heard this from a few people (that you need to detach the hidden state from the computation graph). However, I’m not sure how to implement it in PyTorch. As far as I know, you can use the .detach() function, but I’m not sure how, and the implementations I’ve tried are not working. Could you provide some sample code so that I can take a look at it?

In my implementation, I just added these lines near the beginning of the forward function of the LSTM model.

h_t = h_t.detach()
c_t = c_t.detach()
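For completeness, a minimal sketch of the other placement mentioned above, assuming the train_model function from earlier: detach the returned states in the training loop before feeding them back in, so each backward call only differentiates through the current batch and the model’s forward can stay unchanged.

def train_model(data_loader, model, loss_function, optimizer):
    num_batches = len(data_loader)
    total_loss = 0
    model.train()
    hidden = None
    for X, y in data_loader:
        if hidden is not None:
            # Keep the state values from the previous batch, but cut them
            # out of that batch's computation graph.
            hidden = tuple(h.detach() for h in hidden)
        output, hidden = model(X, hidden)
        loss = loss_function(output, y)

        optimizer.zero_grad()
        loss.backward()  # no retain_graph needed: the graph ends at the detach
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / num_batches
    print(f"Train loss: {avg_loss}")
    return avg_loss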