I’ve been staring at this simple script and can’t work out why it isn’t training. The loss does not budge.
#include <iostream>

#include <torch/torch.h>

using namespace torch::indexing;
// Global device handle: every tensor/module below is moved onto this device.
// NOTE(review): this assumes a CUDA build and a CUDA-capable machine — consider
// guarding with torch::cuda::is_available() and falling back to torch::kCPU.
torch::Device device(torch::kCUDA);
// Small MLP critic: 427 -> 42 -> 286 -> 42 with relu/tanh activations.
struct Critic_Net : torch::nn::Module {
torch::Tensor next_state_batch__sampled_action;
public:
Critic_Net() {
    // BUG FIX: the submodules must be registered with register_module();
    // plain member assignment leaves them invisible to parameters(), so the
    // optimizer received an EMPTY parameter list and the loss never moved.
    lin1 = register_module("lin1", torch::nn::Linear(427, 42));
    lin2 = register_module("lin2", torch::nn::Linear(42, 286));
    lin3 = register_module("lin3", torch::nn::Linear(286, 42));
    lin1->to(device);
    lin2->to(device);
    lin3->to(device);
}
// Forward pass: input is expected to be (batch, 427) — TODO confirm with caller.
torch::Tensor forward(torch::Tensor next_state_batch__sampled_action) {
    auto h = next_state_batch__sampled_action;
    h = torch::relu(lin1->forward(h));
    h = torch::tanh(lin2->forward(h));
    h = lin3->forward(h);
    // NOTE(review): nan_to_num masks NaNs rather than fixing their source;
    // with properly registered, initialized parameters it should be a no-op —
    // consider removing it so real numerical problems surface.
    return torch::nan_to_num(h);
}
torch::nn::Linear lin1{nullptr}, lin2{nullptr}, lin3{nullptr};
};
// Fixed random input batch (42 x 427) and regression targets (42 x 42),
// built once and moved to the GPU.
// NOTE(review): these run CUDA work during static initialization (before
// main); static-init order across TUs is unspecified — safer to construct
// inside main.
auto one = torch::rand({42, 427}).to(device);
auto target = torch::rand({42, 42}).to(device);
// Reused across loop iterations; globals only to keep the training loop terse.
torch::Tensor y_hat;
at::Tensor loss;
auto net = Critic_Net();
// NOTE(review): net.parameters() is snapshotted here — because the Linear
// layers are assigned without register_module(), this list is empty and Adam
// updates nothing. That is the likely cause of the flat loss; register the
// submodules in the constructor.
auto net_optimizer = torch::optim::Adam(net.parameters(), 1e-3);
int main() {
for (int e = 0; e<2000; e++) {
net_optimizer.zero_grad();
y_hat = net.forward(one);
loss = torch::mse_loss(target, y_hat);
loss.backward();
net_optimizer.step();
if (e % 50 == 0) {
std::cout << loss.item() << " " << "\n";
}
}
}