I am trying to implement cross validation by running an instance of my LSTM model on different crossvalidation fold datasets. The issue I’m having is that the model is remembering the weights of the parameters with each subsequent run of the cross validation. What is the easiest way to reset the weights of the model so that each cross validation fold starts from some random initial state and is not learning from the previous folds?
Here is my model as currently defined:
class LSTMModel(nn.Module):
'''
LSTM Model parameters
'''
def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
super(LSTMModel, self).__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.output_dim = output_dim
self.num_layers = num_layers
self.lstm = nn.LSTM(self.input_dim, self.hidden_dim, self.num_layers)
self.forecast = nn.Linear(self.hidden_dim, self.output_dim)
def forward(self, x):
batch_size = 1
h0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim).requires_grad_()
c0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim).requires_grad_()
out, (hn, cn) = self.lstm(x, (h0.detach(), c0.detach()))
out = self.forecast(out[:, -1, :])
out = out.unsqueeze_(-1)
return out
def train_model(datafile, date_col, target_col, hidden_dim,
num_layers, output_dim, num_epochs, learning_rate):
dataset = CreateDataset(datafile, date_col, target_col)
input_dim = dataset.num_features
X_train_sets = dataset.X_train_sets
y_train_sets = dataset.y_train_sets
X_test_sets = dataset.X_test_sets
y_test_sets = dataset.y_test_sets
model = LSTMModel(input_dim, hidden_dim, num_layers, output_dim)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
hist = np.zeros(num_epochs)
test_error = []
for X_train, y_train, X_test, y_test in zip(
X_train_sets, y_train_sets, X_test_sets, y_test_sets):
for epoch in range(num_epochs):
optimizer.zero_grad()
output = model.forward(X_train)
loss = criterion(output, y_train)
if epoch % 100 == 0:
print('Epoch ', epoch, 'Loss: ', loss.item())
hist[epoch] = loss.item()
loss.backward()
optimizer.step()