There are two ways you can run an LSTM. One is to handle the additional loop over time steps yourself, inside your train loop. The other is to let PyTorch handle it in C++ under the hood. The latter is faster and takes less code.
Your current code takes the outputs of the dataloader, which are likely shuffled, and sends them straight into the model. That means the sequential information is likely being lost at the dataloader.
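A rough sketch of what keeping the order could look like, assuming your data is one long series held in a tensor. `WindowDataset` is a made-up helper name, not something from your code or from PyTorch. Each item is a contiguous window, so shuffling only reorders the windows and never the steps inside them:

```python
import torch
from torch.utils.data import Dataset, DataLoader


class WindowDataset(Dataset):
    """Hypothetical helper: slices one long series into contiguous windows."""

    def __init__(self, series, targets, seq_length):
        # series: (total_steps, features), targets: (total_steps,)
        self.series = series
        self.targets = targets
        self.seq_length = seq_length

    def __len__(self):
        # number of full windows that fit into the series
        return self.series.size(0) - self.seq_length + 1

    def __getitem__(self, idx):
        window = slice(idx, idx + self.seq_length)
        return self.series[window], self.targets[window]


# toy data: 100 steps with a single feature, dummy targets
series = torch.arange(100, dtype=torch.float32).unsqueeze(1)
targets = torch.zeros(100, dtype=torch.long)

loader = DataLoader(WindowDataset(series, targets, seq_length=10),
                    batch_size=20, shuffle=True)
X, y = next(iter(loader))
print(X.shape, y.shape)  # torch.Size([20, 10, 1]) torch.Size([20, 10])
```

Here the windows overlap by all but one step. If you would rather have non-overlapping windows, you could step `idx` in multiples of `seq_length` instead.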
Either way, you need the dataloader to return actual cross sections of your sequences. You want the dataloader to give out something of size (batch_size, sequence_length, features), where sequence_length can be any value you choose (ideally a larger value if you are using an LSTM, to take advantage of its abilities). With your current setup, you would then need to send (batch_size, 1, features) into the model one step at a time, where dim = 1 proceeds sequentially through each step. This means you would need a second loop inside the train loop. Alternatively, you could do something like this:
import torch
import torch.nn as nn
import math
device = torch.device("cpu")
class SimpleLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(SimpleLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)
    def forward(self, x):
        # x: (batch_size, sequence_length, features); every batch starts from a zero state
        self.h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device=device)
        self.c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device=device)
        out, _ = self.lstm(x, (self.h0, self.c0))
        # out: (batch_size, sequence_length, hidden_size); the linear layer is applied at every step
        out = self.fc(out)
        return out
def get_class_targets(X):
    # get sine values
    raw_targets = torch.sin(X)
    # assign 3 classes with a roughly 0.9, 0.05, 0.05 distribution
    targets = torch.zeros_like(raw_targets)
    mask1 = raw_targets > math.sin(0.45 * math.pi)
    mask2 = raw_targets < math.sin(-0.45 * math.pi)
    targets[mask1] = 1.
    targets[mask2] = 2.
    return targets
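def get_class_accuracy(output, y, a_class):
    batch_size = output.size(0)
    seq_length = output.size(1)
    # flatten to (batch_size * seq_length,) predicted class indices
    output = output.reshape(batch_size * seq_length, 3).detach().argmax(dim=1)
    y = y.reshape(-1)
    # keep only the positions whose true label is a_class
    output = output[y == a_class]
    y_filtered = y[y == a_class]
    matching = output == y_filtered
    # note: this is nan when the batch contains no samples of a_class
    return torch.sum(matching.to(dtype=torch.long)) / (y_filtered.size(0))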
def train_model(model, loss_function, optimizer):
    num_batches = 50000
    total_loss = 0
    model.train()
    batch_size = 20
    seq_length = 10
    for i in range(num_batches):  # with a dataloader, you could just create an i = 0 outside the loop and then use i += 1 in the loop
        # this just makes a batch of integer sequences with random starting points from 0 to 7, each counting up for 10 steps
        # (torch.arange replaces the deprecated torch.range; the cast keeps X as float, as before)
        X = (torch.randint(0, 8, (batch_size, 1)) + torch.arange(seq_length).unsqueeze(0)).float()
        y = get_class_targets(X).to(dtype=torch.long)
        output = model(X.unsqueeze(2))  # (batch_size, seq_length, 1) in, (batch_size, seq_length, 3) out
        loss = loss_function(output.reshape(batch_size * seq_length, 3), y.reshape(-1))  # only for Cross Entropy
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        accuracy0 = get_class_accuracy(output, y, 0).item()
        accuracy1 = get_class_accuracy(output, y, 1).item()
        accuracy2 = get_class_accuracy(output, y, 2).item()
        print("Loss", loss.item(), "Class0 Accuracy", accuracy0, "Class1 Accuracy", accuracy1, "Class2 Accuracy", accuracy2)
    avg_loss = total_loss / num_batches
    return avg_loss
model = SimpleLSTM(input_size=1, hidden_size=256, num_layers=1, num_classes=3)
optimizer = torch.optim.SGD(model.parameters(), lr = 0.01)
weights = 1/(3*torch.tensor((0.9,0.05,0.05)))
print(weights)
criterion = nn.CrossEntropyLoss(weight = weights)
avg_loss = train_model(model, criterion, optimizer)
In the above, I train your model on sequences of integers, and it must predict the class of the sine value at each step of the sequence. Note that I send (batch_size, sequence_length, features) into the model, and the model returns (batch_size, sequence_length, classes). So this lets the LSTM handle the loop internally.
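For comparison, here is a small sketch of the other option, where you run the loop over time steps yourself and carry the hidden state along. The sizes and the names `out_full` and `out_loop` are just for illustration. Both versions should produce the same outputs up to floating point error, which is why you may as well let the LSTM do it:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=1, hidden_size=8, num_layers=1, batch_first=True)
x = torch.randn(4, 10, 1)  # (batch_size, sequence_length, features)

# Option 1: hand over the whole sequence, the LSTM loops internally
out_full, _ = lstm(x)

# Option 2: feed one step at a time and pass the state along yourself
state = None  # None means the LSTM starts from a zero hidden and cell state
step_outputs = []
for t in range(x.size(1)):
    out_t, state = lstm(x[:, t:t + 1, :], state)  # input is (batch_size, 1, features)
    step_outputs.append(out_t)
out_loop = torch.cat(step_outputs, dim=1)

print(out_full.shape)  # torch.Size([4, 10, 8])
print(torch.allclose(out_full, out_loop, atol=1e-5))
```

The important part in the manual version is passing `state` back in at every step. If you drop it, every step starts from zeros and the LSTM never sees any history.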
The above problem also has an unbalanced class distribution of 0.9, 0.05, 0.05. I included an accuracy metric broken down by class so you can see that it is learning.
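If you want to see where the 0.9, 0.05, 0.05 split comes from, here is a quick check of the two thresholds over one full period of the sine. This assumes inputs spread evenly over the period, so the integer inputs used in the training loop will only match it approximately:

```python
import math
import torch

# dense grid over one period of the sine
x = torch.linspace(0, 2 * math.pi, 100001, dtype=torch.float64)
s = torch.sin(x)

# same thresholds as in get_class_targets
frac1 = (s > math.sin(0.45 * math.pi)).double().mean().item()
frac2 = (s < math.sin(-0.45 * math.pi)).double().mean().item()
frac0 = 1 - frac1 - frac2

print(round(frac0, 3), round(frac1, 3), round(frac2, 3))  # 0.9 0.05 0.05
```

The loss weights in the code are then just 1 / (3 * fraction), so the two rare classes count for more in the loss.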