Hi, I have a small dataset of 15 pairs (a 3D vector and a scalar double) that I want to train a neural network on using the PyTorch C++ API. I wrote a training loop that compiles and runs, but there is no convergence at all: the MSE loss on the training data stays exactly the same. Here is the code:
// - Networks are stored in cut cells for re-use.
auto model = torch::nn::Sequential(
    torch::nn::Linear(3, 16),  // FIXME: Abstract depth.
    torch::nn::Tanh(),         // FIXME: Abstract act. func.
    torch::nn::Linear(16, 1)
);

// - Minimize the MSE error at stencil points.
torch::optim::Adam optimizer(
    model->parameters(),
    0.001  // learning rate
);

torch::Tensor training_prediction = torch::zeros_like(values);
torch::Tensor loss_values = torch::zeros_like(values);

for (size_t epoch = 1; epoch <= MAX_ITERATIONS; ++epoch)
{
    optimizer.zero_grad();

    training_prediction = model->forward(points);
    loss_values = torch::mse_loss(training_prediction, values);

    optimizer.step();

#ifdef NN_DEBUG
    std::cout << "Epoch = " << epoch
              << "\n max stencil loss = "
              << loss_values.max().item<double>()
              << std::endl;
#endif

    if (loss_values.max().item<double>() < EPSILON)
        break;
}
points and values are defined above; their construction is somewhat involved, so I can post a link to the full code if that helps. What I would like to know is whether there is anything wrong with the training loop itself.
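For context, the setup is roughly equivalent to the following simplified sketch (the data and constants here are made up just to show the shapes; my real points and values come from the stencil):

#include <torch/torch.h>
#include <cstddef>

// Stand-in for the real data: 15 pairs of (3D point, scalar value).
// My actual tensors are built from the stencil; only the shapes matter here.
torch::Tensor points = torch::rand({15, 3});   // inputs:  15 x 3
torch::Tensor values = torch::rand({15, 1});   // targets: 15 x 1

constexpr std::size_t MAX_ITERATIONS = 1000;   // assumed iteration cap
constexpr double EPSILON = 1.0e-8;             // assumed convergence tolerance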
Here is the (non-converging) output:
Epoch = 1
max stencil loss = 0.00478412
Epoch = 2
max stencil loss = 0.00478412
Epoch = 3
max stencil loss = 0.00478412
Epoch = 4
max stencil loss = 0.00478412
Epoch = 5
max stencil loss = 0.00478412
Epoch = 6
max stencil loss = 0.00478412
Epoch = 7
max stencil loss = 0.00478412
Epoch = 8
max stencil loss = 0.00478412
Epoch = 9
max stencil loss = 0.00478412
Epoch = 10
max stencil loss = 0.00478412
Epoch = 11
max stencil loss = 0.00478412
Epoch = 12
max stencil loss = 0.00478412
Epoch = 13
max stencil loss = 0.00478412
...
Why does max(loss_mse) stay constant? I understand that 15 samples is not much data to train on, but I expected at least some convergence, so I suspect there is something wrong with the training loop.