Running a simple inference loop, I noticed that memory usage slowly increases, which will eventually lead to a crash. I tried adding detach() and wrapping the loop in a no-grad guard, but nothing seems to help.
while (true) {
    auto Q_val = dqn_module.forward(torch::zeros({10, 4}));
    auto Q_val_max = Q_val.max_values(1, true);
}
dqn_module is simple:
seq = Sequential(
    Linear(g.num_observations, 64),
    Functional(relu),
    Linear(64, 32),
    Functional(relu),
    Linear(32, g.num_actions)
);