#include <torch/torch.h>
#include <torch/script.h>
#include <iostream>
// Определяем архитектуру сети
struct Net : torch::nn::Module {
torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr};
Net() {
fc1 = register_module("fc1", torch::nn::Linear(784, 128));
fc2 = register_module("fc2", torch::nn::Linear(128, 64));
fc3 = register_module("fc3", torch::nn::Linear(64, 10));
}
torch::Tensor forward(torch::Tensor x) {
x = x.view({-1, 784}); // Преобразуем входные данные в вектор
x = torch::relu(fc1->forward(x));
x = torch::relu(fc2->forward(x));
x = fc3->forward(x);
return torch::log_softmax(x, 1); // Выход с log_softmax
}
};
int main() {
// Указываем устройство CPU
torch::Device device(torch::kCPU);
// Создаем модель
auto net = std::make_shared<Net>();
net->to(device);
auto data = torch::randn({32, 784}); // 32 изображения
auto target = torch::randint(0, 10, {32}); // Метки
torch::optim::SGD optimizer(net->parameters(), torch::optim::SGDOptions(0.01));
for (int epoch = 0; epoch < 10; epoch++) {
optimizer.zero_grad();
torch::Tensor output = net->forward(data);
torch::Tensor loss = torch::nll_loss(output, target);
loss.backward();
optimizer.step();
std::cout << "Epoch: " << epoch << ", Loss: " << loss.item<float>() << std::endl;
}
std::string model = "neut_model.pth";
torch::save(net, model);
return 0;
}