Newbie. C++. 2-class classifier - how to make training faster

Hi Guys,
I have data sets consisting of about 100,000 training examples. Every example has roughly 100 features (normalized between -2 and 2) and 2 labels (1, 0 = good; 0, 1 = bad), which look something like this:

112 inputs | 2 outputs
-0.58572,-0.0248404,…,-0.926154,1.14224,-0.731815 | 0,1
-1.05784,-0.71582,…,0.989832,1.99402,-1.60195 | 0,1
1.92291,1.79471,…,-0.339433,0.572841,-0.306174 | 1,0

These data sets are generated by C++ code and then passed to a custom handcrafted neural net. I was asked to rewrite the old NN part using libtorch. I used the following NN definition (which more or less follows the old NN structure):

class TorchNet : public torch::nn::Module {
    torch::nn::Linear input{ nullptr };
    torch::nn::Linear hidden1{ nullptr };
    torch::nn::Linear hidden2{ nullptr };
    torch::nn::Linear hidden3{ nullptr };
    torch::nn::Linear hidden4{ nullptr };
    torch::nn::Linear hidden5{ nullptr };
    torch::nn::Linear output{ nullptr };
    int input_size = 0;

public:
    TorchNet(int inputs_count);
    torch::Tensor forward(torch::Tensor x);

    // dataVec - vector of pairs of (input values, targets)
    void train_step(const std::vector<std::pair<torch::Tensor, torch::Tensor>>& dataVec, torch::optim::Optimizer& optimizer);
    // ...
};

TorchNet::TorchNet(int inputs_count) 
{
    input_size = inputs_count;
    input = register_module("input", torch::nn::Linear(inputs_count, 30));
    hidden1 = register_module("hidden1", torch::nn::Linear(30, 20));
    hidden2 = register_module("hidden2", torch::nn::Linear(20, 15));
    hidden3 = register_module("hidden3", torch::nn::Linear(15, 10));
    hidden4 = register_module("hidden4", torch::nn::Linear(10, 5));
    hidden5 = register_module("hidden5", torch::nn::Linear(5, 3));
    output = register_module("output", torch::nn::Linear(3, 2));
}

torch::Tensor TorchNet::forward(torch::Tensor x) {
    x = torch::tanh(input->forward(x));
    x = torch::dropout(x, /*p=*/0.2, /*train=*/is_training());
    x = torch::tanh(hidden1->forward(x));
    x = torch::tanh(hidden2->forward(x));
    x = torch::tanh(hidden3->forward(x));
    x = torch::tanh(hidden4->forward(x));
    x = torch::tanh(hidden5->forward(x));
    x = output->forward(x);
    return torch::log_softmax(x, /*dim=*/-1); // normalize over the class dim, so batched [N, 2] outputs also work
}

// dataVec - vector of pairs of (input values, targets)
void TorchNet::train_step(const std::vector<std::pair<torch::Tensor, torch::Tensor>>& dataVec, torch::optim::Optimizer& optimizer)
{
    train();
    double count = 0;
    double max = dataVec.size();
    double pct = 0;

    for (const auto& sample : dataVec) {
        auto data = sample.first, targets = sample.second.to(at::kLong);
        optimizer.zero_grad();
        auto output = forward(data);
        // data.sizes() = [112], output.sizes() = [2], targets.sizes() = [2]
        auto loss = torch::nll_loss(output, targets);
        AT_ASSERT(!std::isnan(loss.item<float>()));
        loss.backward();
        optimizer.step();
        if (count / max >= pct) 
        {
            std::cout << "after " << (100.0 * count / max) << "% loss = " << loss.item<double>() << std::endl;
            pct += 0.2;
        }
        ++count;
    }
}

To prepare the training data I do the following:

// data - vector of pairs - first: vector of 112 input values, second: vector of 2 labels
double TorchNetTest::train(const std::vector<std::pair<std::vector<double>, std::vector<double>>>& data, ...)
{
...
    auto input_size = data[0].first.size();
    model = std::make_unique<TorchNet>(static_cast<int>(input_size));

    std::vector<std::pair<torch::Tensor, torch::Tensor>> train_vec;
    for (const auto& data_label : data)
    {
        const auto& inputs = data_label.first;  // a vector of 112 doubles
        const auto& labels = data_label.second; // a vector of 2 doubles

        auto sample = std::pair<torch::Tensor, torch::Tensor>();
        sample.first = torch::tensor(c10::ArrayRef(inputs));
        sample.second = torch::tensor(c10::ArrayRef(labels));

        train_vec.push_back(std::move(sample));
    }

    torch::optim::SGD optimizer(model->parameters(),
        torch::optim::SGDOptions(0.01).momentum(0.5));
...
    model->train_step(train_vec, optimizer);
...
}

It works (the model trains to the expected accuracy), but it is very slow compared to the old NN. I am using the CPU version of libtorch (the old NN also runs on CPU only) and I know it could be much faster on a GPU, but I would like to make it at least usable on CPU alone. As I'm a newbie, I'm sure my implementation is far from optimal.
Do you have any suggestions on how to make it more performant?
For example:
Can I rearrange the data to make the training step work faster?
Is it possible to perform training with multiple threads?
Is it possible to modify forward() / nll_loss() so they accept multiple training examples instead of just one?

Thank you in advance.

Your model and loss should accept batches without changes (forward a tensor of size [N x 112]). For example, stacking of inputs:

std::vector<double> data1(112);
std::vector<double> data2(112);
std::vector<double> data3(112);

auto input1 = torch::tensor(c10::ArrayRef(data1));
auto input2 = torch::tensor(c10::ArrayRef(data2));
auto input3 = torch::tensor(c10::ArrayRef(data3));

auto input = torch::stack({ input1, input2, input3 });

auto model = std::make_unique<TorchNet>(112);

auto output = model->forward(input); // size [3, 2]

Instead of creating batches in your own code, you should probably create a custom dataset with a Stack transform, which will do it for you → example
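
For reference, a minimal sketch of what such a dataset might look like for the pair-of-vectors layout from the original post (PairDataset, the kFloat cast, and the batch size of 64 are illustrative choices, not from the thread):

class PairDataset : public torch::data::datasets::Dataset<PairDataset> {
    std::vector<std::pair<std::vector<double>, std::vector<double>>> data_;

public:
    explicit PairDataset(std::vector<std::pair<std::vector<double>, std::vector<double>>> data)
        : data_(std::move(data)) {}

    torch::data::Example<> get(size_t index) override {
        auto input = torch::tensor(data_[index].first).to(torch::kFloat);  // [112]
        // nll_loss expects a class index, so convert the 2-element label here
        auto target = torch::argmax(torch::tensor(data_[index].second));   // 0-dim, kLong
        return { input, target };
    }

    torch::optional<size_t> size() const override { return data_.size(); }
};

// Stack<> collates individual samples into batched tensors: data [N, 112], target [N]
auto dataset = PairDataset(data).map(torch::data::transforms::Stack<>());
auto loader = torch::data::make_data_loader(
    std::move(dataset), torch::data::DataLoaderOptions().batch_size(64));

for (auto& batch : *loader) {
    auto output = model->forward(batch.data);          // [N, 2]
    auto loss = torch::nll_loss(output, batch.target); // one loss per mini-batch
    // ... zero_grad / backward / step as in train_step
}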

Hey Matej,

Thank you for your help. Your first suggestion (stacking) worked for the forward pass, but it failed for the loss. I assumed I should do the same stacking for targets as you showed for inputs.

        auto data_vec = std::vector<torch::Tensor>();
        auto target_vec = std::vector<torch::Tensor>();

        for (auto& data : samples)
        {
            data_vec.push_back(data.first);    // input - torch::tensor(c10::ArrayRef(data_x)); <- [112]
            target_vec.push_back(data.second); // target - torch::tensor(c10::ArrayRef(target_x)); <- [2]
        }

        auto data = torch::stack(data_vec);
        auto targets = torch::stack(target_vec);

        auto output = forward(data);   // <- forward accepts stacked inputs
        try {
            auto loss = torch::nll_loss(output, targets); // <- throws: 0D or 1D target tensor expected, multi-target not supported

            optimizer.zero_grad();
            loss.backward();
            optimizer.step();
        }
        catch (const std::exception& e) {
            std::cout << "Exception occured\n" << e.what() << std::endl;
            exit(-1);
        }

torch::nll_loss(output, targets) throws “0D or 1D target tensor expected, multi-target not supported”. Should I stack inputs belonging to the same class and pass a single target for such a batch?

I will also try your second suggestion (custom dataset) and share the results.

NLLLoss from the docs: the target that this loss expects should be a class index in the range [0, C−1], where C is the number of classes (instead of one-hot vectors). You can use argmax on the stacked tensor (or, preferably, just change the way you load the targets).
Example:

std::vector<double> data1(112);
std::vector<double> data2(112);
std::vector<double> label1 = { 0, 1 };
std::vector<double> label2 = { 1, 0 };

auto inputs = torch::stack({ torch::tensor(data1), torch::tensor(data2) });
auto labels = torch::stack({ torch::tensor(label1), torch::tensor(label2) });

auto model = std::make_unique<TorchNet>(112);

auto output = model->forward(inputs);

labels = torch::argmax(labels, 1); // [N, 2] one-hot -> [N] class indices (kLong)

auto loss = torch::nll_loss(output, labels);
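
Putting it together, the per-sample loop in the original train_step could be reworked into a mini-batch version along these lines (a sketch only; train_step_batched is a made-up name and the batching scheme is illustrative):

// Sketch: iterate over dataVec in chunks, stacking each chunk into one batch.
void TorchNet::train_step_batched(const std::vector<std::pair<torch::Tensor, torch::Tensor>>& dataVec,
                                  torch::optim::Optimizer& optimizer, size_t batch_size)
{
    train();
    for (size_t start = 0; start < dataVec.size(); start += batch_size) {
        size_t end = std::min(start + batch_size, dataVec.size()); // std::min from <algorithm>

        std::vector<torch::Tensor> data_vec, target_vec;
        for (size_t i = start; i < end; ++i) {
            data_vec.push_back(dataVec[i].first);                   // [112]
            target_vec.push_back(torch::argmax(dataVec[i].second)); // one-hot [2] -> 0-dim index
        }

        auto data = torch::stack(data_vec);      // [N, 112]
        auto targets = torch::stack(target_vec); // [N], kLong

        optimizer.zero_grad();
        auto output = forward(data);                  // [N, 2]; one nll_loss call per batch
        auto loss = torch::nll_loss(output, targets); // instead of one per sample
        loss.backward();
        optimizer.step();
    }
}

Note this relies on log_softmax in forward() being taken over the class dimension (the last dim), not dim 0, once the output is 2-D.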

Thank you very much for your help. It worked. It also showed how much I have yet to learn :)