Training a model with an LSTM layer manually

So I’m trying to implement and train a model with an LSTM layer that I’m using to predict the actions of players in a video game. Since in practice the network will be fed data one frame at a time, I’m training it in the same way, by giving the LSTM layer a batch size and sequence length of 1 and feeding in the sequence manually. The data contains 10 output classes, and I’m testing the network by trying to get it to fit 5 samples, each of a different class.
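Concretely, by “feeding in the sequence manually” I mean something roughly like this (shapes only, not my actual model; ‘sequence’ here just stands in for one sample’s frames):

import torch as pt

#toy sketch: one 678-feature frame per call, with batch size and sequence length
#both equal to 1, carrying the hidden state forward between calls
sequence = pt.rand(320, 678)             #placeholder: 320 frames of 678 features each
lstm = pt.nn.LSTM(input_size=678, hidden_size=678)
hidden = (pt.zeros(1, 1, 678), pt.zeros(1, 1, 678))
for frame in sequence:
    out, hidden = lstm(frame.view(1, 1, -1), hidden)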

The issue is that the model is absolutely abysmal at performing this. Even after many epochs it tends to output predictions for only 2 of the classes, with a clear preference for one of them. Even weirder, the model performs much better when the LSTM layer is removed (~55% testing accuracy vs ~35%), which, given the sequential nature of the data, seems particularly perplexing to me and suggests that I’m simply setting up the model incorrectly.

The model currently looks like this:

import torch as pt

#assumed setup (not shown in the original post): 'device' is the usual CUDA-or-CPU selector
device = pt.device("cuda" if pt.cuda.is_available() else "cpu")

class SiLU(pt.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x * pt.sigmoid(x)

activation = SiLU()
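#(recent PyTorch versions also provide this activation built in as pt.nn.SiLU)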

#class to allow specifying the amount and size of layers without hardcoding
class Deep_Layers(pt.nn.Module):
    def __init__(self, Layer_Sizes):
        super(Deep_Layers, self).__init__()
        
        self.Layers = pt.nn.ModuleList([]) #create a new module list
        for i in range(0,len(Layer_Sizes)-1):
            #add the connections between the layer sizes in the tuple
            self.Layers.append(pt.nn.Linear(Layer_Sizes[i],Layer_Sizes[i+1]))

    def forward(self, x):
        for layer in self.Layers:
            #simply feedforward through the layers and apply the activation
            x = activation(layer(x))
        return x

class Network(pt.nn.Module):
    def __init__(self, Deep_Layers_sizes, LSTM_hidden_size):
        super(Network, self).__init__()

        self.lstm = pt.nn.LSTM(Deep_Layers_sizes[0], LSTM_hidden_size)
        self.deep = Deep_Layers(Deep_Layers_sizes)

        self.lstm_size = LSTM_hidden_size

        self.init_hidden()

    def forward(self, x):
        x = x.to(device)

        #pass data through the single lstm layer
        x, self.hidden = self.lstm(x.view(1, 1, len(x)), self.hidden)
        #final pass through stacked deep layers
        x = self.deep(x)
        return pt.sigmoid(x)

    def init_hidden(self):
        self.hidden = (pt.zeros(1, 1, self.lstm_size).to(device),
                       pt.zeros(1, 1, self.lstm_size).to(device))

The training data is supplied as sequences of frames of game-state information leading up to one of the players performing an action, and each sequence is labelled with the index of the player who performs it.
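For reference, the training loop below treats each element of data roughly like this (my own sketch of the shapes; the actual loading code isn’t shown):

import torch as pt

#hypothetical shape of one training sample, matching how 'i' is indexed in the loop below
frames = pt.rand(320, 678)               #i[0]: 320 frames of 678 game-state features each
label = pt.tensor([3])                   #i[1]: index (0-9) of the player who performs the action
sample = (frames, label)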

My training loop currently looks like this:

deep_sizes = (678,512,488,256,256,10)
lstm_h_size = 678 #hidden features must be same as input features for some reason
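#(this is because the LSTM output, which has lstm_h_size features, is fed straight into
# the first Linear layer of Deep_Layers, which expects deep_sizes[0] inputs; the sizes
# could be decoupled by making that first layer take lstm_h_size inputs instead)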

n_epochs = 100
Lr=0.0001

net = Network(deep_sizes,lstm_h_size)
net.to(device)
criterion = pt.nn.CrossEntropyLoss()
optimizer = pt.optim.Adam(net.parameters(), Lr)

losses = []
for e in range(n_epochs):
    avg_loss = 0
    per = 0
    for i in data:
        total_loss = 0
        optimizer.zero_grad()
        #train using a 'sliding window' covering half the data
        for k in range(0, 160):
            #reset the hidden layer to clear history
            net.init_hidden()

            #warm up on the 160 frames (~5 seconds) preceding the target frame to build up the hidden state
            for j in range(k, k+160):
                output = net(i[0][j])
            #process the final frame of the window
            output = net(i[0][k+160])
            #calculate loss from the last frame
            total_loss += criterion(output.view(1,-1),i[1].to(device))
        #backprop and optimize using summed loss for every example in the sequence
        total_loss.backward()
        optimizer.step()
        per += 1
        avg_loss += total_loss.item() / 160
        print(str(per)+": "+str(total_loss.item())+" total loss")
    print("epoch "+str(e)+" average loss: "+str(avg_loss/n_datapoints))
    if (len(losses) > 0):
        print("Change in loss from last epoch: "+str(avg_loss - losses[-1]))
    losses.append(avg_loss/n_datapoints)

To explain the training process: each example contains 320 frames, and I want the network to be able to make predictions within a 160-frame window. So, for each frame in the second half of the data, I first process the preceding 160 frames without calculating the loss, to give the model its history and allow it to perform BPTT, and then calculate the loss on that final frame. This is repeated until the loss has been calculated for the whole second half of the sequence, and then backprop is performed using the loss accumulated over the entire training sample (so each sample produces 160 windows of 161 forward passes, all feeding a single backward call).

I’m thinking that, since I’m new to PyTorch and RNNs, I’m just making a stupid mistake somewhere. Thank you to anyone who can help point out what I’m doing wrong!

Not a direct answer, but you could try a Conv3d network, where you can feed the input in naturally with N frames at a time.
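Something along these lines (a rough sketch; I’m assuming the frames can be arranged as image-like tensors, and all of the sizes are invented):

import torch as pt

#rough Conv3d sketch: frames stacked along the depth dimension, with input shape
#(batch, channels, frames, height, width); every size below is made up
conv = pt.nn.Conv3d(in_channels=3, out_channels=16, kernel_size=(3, 3, 3))
clip = pt.rand(1, 3, 160, 64, 64)        #1 clip, 3 colour channels, 160 frames of 64x64
features = conv(clip)                    #output shape: (1, 16, 158, 62, 62)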

I actually hadn’t considered using a convolutional network! I’ll look into how that performs, thank you!

Just some suggestions from me, and I hope they help :slight_smile:

  1. In this case, where you are feeding 1 frame at a time, you can replace the LSTM with an LSTMCell, though it’s not necessary (see the sketch after these suggestions).
  2. You really should add some dropout inside your model, because both LSTM and Linear layers suffer badly from over-fitting.
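For 1. (and a bit of 2.), here is a rough sketch of what the swap could look like, reusing the 678 sizes from your post; ‘sequence’ is just a placeholder for one sample’s frames, and this isn’t meant as a drop-in replacement for your Network class:

import torch as pt

#sketch only: LSTMCell takes one (batch, features) frame per call instead of a
#(seq_len, batch, features) tensor, and the hidden/cell states are passed explicitly
cell = pt.nn.LSTMCell(input_size=678, hidden_size=678)
drop = pt.nn.Dropout(p=0.5)              #placeholder dropout probability
sequence = pt.rand(160, 678)             #placeholder: 160 frames of 678 features each
h = pt.zeros(1, 678)
c = pt.zeros(1, 678)
for frame in sequence:
    h, c = cell(frame.view(1, -1), (h, c))
    out = drop(h)                        #dropout before any feed-forward layers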