I have implemented a simple MLP to train on a dataset. I’m using the “ignite” wrapper to simplify the process. However, the loss is neither decreasing nor increasing. The code I’m using for the training is as follows:
class MLP(nn.Module):
    """Two-layer perceptron: 2048 -> 1024 (Tanh) -> 978."""

    def __init__(self):
        super().__init__()
        # Assemble the same stack as before, just named step by step.
        hidden = nn.Linear(2048, 1024)
        activation = nn.Tanh()
        head = nn.Linear(1024, 978)
        self.layers = nn.Sequential(hidden, activation, head)

    def forward(self, x):
        # Feed the input straight through the sequential stack.
        return self.layers(x)
# Model, optimizer and loss.  NOTE: .double() casts every parameter to
# float64, so the input tensors below must be float64 too —
# torch.from_numpy keeps the numpy dtype (assumes X_train / y_train are
# float64 arrays; TODO confirm against how they were produced).
model = MLP().double()
optimizer = torch.optim.Adam(model.parameters())
loss = torch.nn.MSELoss()

# Prefer the GPU when one is available.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

trainer = create_supervised_trainer(model=model, optimizer=optimizer, loss_fn=loss, device=device)
evaluator = create_supervised_evaluator(
    model,
    metrics={"MSE": MeanSquaredError(), "MAE": MeanAbsoluteError()},
    device=device,
)

# Training data.  The original DataLoader used the defaults
# (batch_size=1, shuffle=False): one sample per optimizer step, always in
# the same order — very slow and very noisy gradients.  Use shuffled
# mini-batches instead.
tensor_x_train = torch.from_numpy(X_train)
tensor_y_train = torch.from_numpy(y_train)
train_dataset = utils.TensorDataset(tensor_x_train, tensor_y_train)
train_loader = utils.DataLoader(train_dataset, batch_size=64, shuffle=True)

# Validation data: order does not matter, but batching still speeds up
# evaluation compared to the default batch_size=1.
tensor_x_test = torch.from_numpy(X_test)
tensor_y_test = torch.from_numpy(y_test)
test_dataset = utils.TensorDataset(tensor_x_test, tensor_y_test)
val_loader = utils.DataLoader(test_dataset, batch_size=64)
@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(engine):
    """Print the current batch loss after every training iteration.

    BUG FIX: the original call was
    ``"Epoch[{}] Loss: {:.2f}".format(engine.state.epoch, len(train_loader),
    engine.state.output)`` — two placeholders but three arguments.
    str.format silently ignores the extra argument, so the value printed as
    "Loss" was ``len(train_loader)``, a constant.  That is why the logged
    loss never appeared to change.  Here the loss placeholder actually
    receives ``engine.state.output`` (the batch loss from the trainer).
    """
    print("Epoch[{}] Loss: {:.2f}".format(engine.state.epoch, engine.state.output))
@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(trainer):
    """At the end of each epoch, evaluate on the training set and report."""
    evaluator.run(train_loader)
    epoch_metrics = evaluator.state.metrics
    mse = epoch_metrics['MSE']
    mae = epoch_metrics['MAE']
    print("Training Results - Epoch: {} Avg MSE: {:.2f} Avg MAE: {:.2f}"
          .format(trainer.state.epoch, mse, mae))
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
    """At the end of each epoch, evaluate on the held-out set and report."""
    evaluator.run(val_loader)
    epoch_metrics = evaluator.state.metrics
    mse = epoch_metrics['MSE']
    mae = epoch_metrics['MAE']
    print("Validation Results - Epoch: {} Avg MSE: {:.2f} Avg MAE: {:.2f}"
          .format(engine.state.epoch, mse, mae))
# Kick off training: 100 full passes over the training data.
trainer.run(train_loader, max_epochs=100)
I did check the gradients of the parameters and they are non-zero values. However, the loss is not changing. Where did I go wrong?
Thanks in advance.