Libtorch regression modeling bug

Hi! I have hit a libtorch error that I can't find a solution for.
I want to build a regression model with a 3-dimensional input using libtorch,
but I am stuck at this step.

Can anyone help me?

Here is the error message.

(base) myaccount@ubuntu:{mydirectory}$ ./{libtorch_exe_file}
terminate called after throwing an instance of 'c10::Error'
  what():  falseINTERNAL ASSERT FAILED at
 "{mylibtorchdirectory}/include/torch/csrc/api/include/torch/detail/TensorDataContainer.h":299, please report a bug to PyTorch. TensorDataContainer is already a Tensor type, `fill_tensor` should not be called
Exception raised from fill_tensor at {mylibtorchdirectory}/include/torch/csrc/api/include/torch/detail/TensorDataContainer.h:299 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x7f019c304a22 in {mylibtorchdirectory}/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x5b (0x7f019c3013db in {mylibtorchdirectory}/lib/libc10.so)
frame #2: c10::detail::torchInternalAssertFail(char const*, char const*, unsigned int, char const*, char const*) + 0x42 (0x7f019c301942 in {mylibtorchdirectory}/lib/libc10.so)
frame #3: torch::detail::TensorDataContainer::fill_tensor(at::Tensor&) const + 0x28b (0x556357f08961 in ./{libtorch_exe_file})
frame #4: torch::detail::TensorDataContainer::fill_tensor(at::Tensor&) const + 0x235 (0x556357f0890b in ./{libtorch_exe_file})
frame #5: torch::detail::TensorDataContainer::convert_to_tensor(c10::TensorOptions) const + 0x11c (0x556357f05c7a in ./{libtorch_exe_file})
frame #6: torch::tensor(torch::detail::TensorDataContainer, c10::TensorOptions const&) + 0x73 (0x556357f0950a in ./{libtorch_exe_file})
frame #7: MYDataset::get(unsigned long) + 0x119 (0x556357f0c7ad in ./{libtorch_exe_file})
frame #8: torch::data::datasets::Dataset<MYDataset, torch::data::Example<at::Tensor, at::Tensor> >::get_batch(c10::ArrayRef<unsigned long>) + 0xae (0x556357f2e2dc in ./{libtorch_exe_file})
frame #9: torch::data::Example<at::Tensor, at::Tensor> torch::data::datasets::MapDataset<MYDataset, torch::data::transforms::Stack<torch::data::Example<at::Tensor, at::Tensor> > >::get_batch_impl<MYDataset, void>(c10::ArrayRef<unsigned long>) + 0x5f (0x556357f29d43 in ./{libtorch_exe_file})
frame #10: torch::data::datasets::MapDataset<MYDataset, torch::data::transforms::Stack<torch::data::Example<at::Tensor, at::Tensor> > >::get_batch(c10::ArrayRef<unsigned long>) + 0x4d (0x556357f25d97 in ./{libtorch_exe_file})
frame #11: torch::data::DataLoaderBase<torch::data::datasets::MapDataset<MYDataset, torch::data::transforms::Stack<torch::data::Example<at::Tensor, at::Tensor> > >, torch::data::Example<at::Tensor, at::Tensor>, std::vector<unsigned long, std::allocator<unsigned long> > >::next() + 0x20c (0x556357f1f8cc in ./{libtorch_exe_file})
frame #12: torch::data::DataLoaderBase<torch::data::datasets::MapDataset<MYDataset, torch::data::transforms::Stack<torch::data::Example<at::Tensor, at::Tensor> > >, torch::data::Example<at::Tensor, at::Tensor>, std::vector<unsigned long, std::allocator<unsigned long> > >::begin()::{lambda()#1}::operator()() const + 0x35 (0x556357f183cd in ./{libtorch_exe_file})
frame #13: std::_Function_handler<c10::optional<torch::data::Example<at::Tensor, at::Tensor> > (), torch::data::DataLoaderBase<torch::data::datasets::MapDataset<MYDataset, torch::data::transforms::Stack<torch::data::Example<at::Tensor, at::Tensor> > >, torch::data::Example<at::Tensor, at::Tensor>, std::vector<unsigned long, std::allocator<unsigned long> > >::begin()::{lambda()#1}>::_M_invoke(std::_Any_data const&) + 0x3d (0x556357f29f27 in ./{libtorch_exe_file})
frame #14: std::function<c10::optional<torch::data::Example<at::Tensor, at::Tensor> > ()>::operator()() const + 0x4c (0x556357f39f0a in ./{libtorch_exe_file})
frame #15: torch::data::detail::ValidIterator<torch::data::Example<at::Tensor, at::Tensor> >::lazy_initialize() const + 0x41 (0x556357f39e61 in ./{libtorch_exe_file})
frame #16: torch::data::detail::ValidIterator<torch::data::Example<at::Tensor, at::Tensor> >::operator==(torch::data::detail::SentinelIterator<torch::data::Example<at::Tensor, at::Tensor> > const&) const + 0x1c (0x556357f39742 in ./{libtorch_exe_file})
frame #17: torch::data::detail::SentinelIterator<torch::data::Example<at::Tensor, at::Tensor> >::operator==(torch::data::detail::ValidIterator<torch::data::Example<at::Tensor, at::Tensor> > const&) const + 0x2e (0x556357f3955e in ./{libtorch_exe_file})
frame #18: torch::data::detail::ValidIterator<torch::data::Example<at::Tensor, at::Tensor> >::operator==(torch::data::detail::IteratorImpl<torch::data::Example<at::Tensor, at::Tensor> > const&) const + 0x2e (0x556357f3970a in ./{libtorch_exe_file})
frame #19: torch::data::Iterator<torch::data::Example<at::Tensor, at::Tensor> >::operator==(torch::data::Iterator<torch::data::Example<at::Tensor, at::Tensor> > const&) const + 0x41 (0x556357f1fea9 in ./{libtorch_exe_file})
frame #20: torch::data::Iterator<torch::data::Example<at::Tensor, at::Tensor> >::operator!=(torch::data::Iterator<torch::data::Example<at::Tensor, at::Tensor> > const&) const + 0x23 (0x556357f18627 in ./{libtorch_exe_file})
frame #21: main + 0x400 (0x556357f008a9 in ./{libtorch_exe_file})
frame #22: __libc_start_main + 0xe7 (0x7f0145a06bf7 in /lib/x86_64-linux-gnu/libc.so.6)
frame #23: _start + 0x2a (0x556357eff5ca in ./{libtorch_exe_file})

Aborted (core dumped)

And this is my CMakeLists.txt.

cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(custom_ops)

find_package(Torch REQUIRED)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
add_executable({libtorch_exe_file} {libtorch_exe_file}.cpp)
target_link_libraries({libtorch_exe_file} "${TORCH_LIBRARIES}")
set_property(TARGET {libtorch_exe_file} PROPERTY CXX_STANDARD 14)

set(CMAKE_FIND_LIBRARY_SUFFIXES ".a")

Below are the main and header code, where I suspect the bug is.
main

#include <torch/torch.h>
#include <torch/script.h>
#include <cmath>
#include <iostream>
#include <iomanip>
#include <memory>
#include <fstream>
#include <vector>
#include <time.h>
#include <torch/utils.h>
#include <limits>
#include "csvloader.h"
#include "utils.h"

using namespace std;
using namespace torch;

/*torch Basic model*/
struct BasicModel : nn::Module {
    BasicModel(int64_t input_dim=3, int64_t output_dim=1):
        linear
        ~ (block structure)
    {
        register_module("linear",linear);
        ~
    }
    torch::Tensor forward(torch::Tensor x){
        ~ forward ~
        return x;
    }
    ~variables~;
};

int main(int argc, char** argv)
{
    // torch device checking
    torch::Device device(torch::kCPU);
    if (torch::cuda::is_available()){
        device = torch::Device(torch::kCUDA,1);
    }

    BasicModel model;
    model.to(device);

    // Load csv data
    string path = argc > 1? argv[1] : "{csv_path}.csv";

    torch::data::DataLoaderOptions train_options;
    torch::data::DataLoaderOptions valid_options;
    train_options.batch_size(128);
    valid_options.batch_size(128);
    auto dataset = MYDataset(path).map(torch::data::transforms::Stack<>());
    auto dataloader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(dataset,train_options);
    torch::optim::Adam optimizer(model.parameters(), torch::optim::AdamOptions(0.001));
    float train_loss = 0.0f;

    clock_t tstart = clock();
    for (int64_t epoch=0; epoch<2; epoch++){
        train_loss = 0.0f; // reset the running loss at the start of each epoch
        float cnt = 0;
        for (auto& batch : *dataloader){
            //model.zero_grad();
            torch::Tensor x = batch.data.to(device);
            torch::Tensor y = batch.target.to(device);
            auto output = model.forward(x);
            auto loss = torch::nn::functional::mse_loss(output,y);
            train_loss += loss.item<float>();
            optimizer.zero_grad();
            loss.backward();
            optimizer.step();
            cnt += 1.;
        };
        train_loss /= cnt;
        if (epoch % 500 == 0){
            cout << "Epoch : " << epoch << " Train loss: " << train_loss << endl;
        }
    };
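    // clock() returns processor ticks, so convert to milliseconds via CLOCKS_PER_SEC
    cout << "Time taken: " << (clock()-tstart) * 1000.0 / CLOCKS_PER_SEC << "ms\n" << endl;
    return 0;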
};

header

auto ReadCsv(std::string& location){
    std::fstream in(location, std::ios::in);
    std::string line;
    std::string name;
    std::string label;
    std::string dummy;
    float X1;
    float X2;
    float X3;
    float Y;
    std::vector<std::tuple<std::vector<float>,float>> csv;
    std::vector<float> x;
    getline(in,line); // Skipping first row
    while (getline(in,line)) {
        std::stringstream s(line);
        getline(s,dummy,','); 
        getline(s,dummy,','); 
        getline(s,dummy,','); 
        getline(s,dummy,',');
        X1 = stof(dummy);
        getline(s,dummy,',');
        X2 = stof(dummy);
        getline(s,dummy,',');
        X3 = stof(dummy);
        getline(s,dummy,',');
        Y = stof(dummy);
        x.push_back(X1);
        x.push_back(X2);
        x.push_back(X3);
        csv.push_back(std::make_tuple(x,Ids));
        x.clear();
    }
    return csv;
};

struct MYDataset : torch::data::Dataset<MYDataset>
{
    std::vector<std::tuple<std::vector<float> /*x_data*/, float /*y_data*/>> csv_;
    MYDataset(std::string& file_name_csv)
        // Load csv file with file locations and labels.
        : csv_(ReadCsv(file_name_csv)) {
    };

    //override the get method to load custom data.
    torch::data::Example<> get(size_t index) override {
        std::vector<float> x = std::get<0>(csv_[index]);
        float y = std::get<1>(csv_[index]);

        torch::Tensor x_input = torch::tensor({x});
        torch::Tensor y_label = torch::tensor({y});
        return {x_input, y_label};
    };

    // Override the size method to infer the size of the data set.
    torch::optional<size_t> size() const override {
        return csv_.size();
    };
};

The CSV format is like this:

x1 x2 x3 y
0.5 0.2 0.3 1

Hi, I tried running your code, but a variable is missing: ‘Ids’ in ‘ReadCsv’. How many columns does your CSV have? You wrote that it is x1, x2, x3, y, but your code has three extra getline calls before parsing x1, x2, …

I left the unwanted columns out of the CSV description above.
The actual CSV format is like this:

Unwanted value1 Unwanted value2 Unwanted value3 X1 X2 X3 Y
dummy dummy dummy 1 2 3 5
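
For reference, the column layout above could be parsed along these lines. This is only a sketch using the standard library: `Row` and `ParseRow` are hypothetical names that do not appear in the original code, and it assumes the file is comma-delimited (as the `getline(s,dummy,',')` calls in `ReadCsv` suggest) with well-formed numeric fields.

```cpp
#include <sstream>
#include <string>
#include <vector>

// Hypothetical helper type (not from the original post): one parsed CSV row.
struct Row {
    std::vector<float> x;  // X1, X2, X3
    float y;               // Y
};

// Parses one comma-separated line laid out as
//   unwanted1,unwanted2,unwanted3,X1,X2,X3,Y
// Assumes a comma delimiter and well-formed numeric fields;
// std::stof throws std::invalid_argument on a non-numeric field.
inline Row ParseRow(const std::string& line) {
    std::stringstream s(line);
    std::string field;

    // Skip the three unwanted leading columns.
    for (int i = 0; i < 3; ++i) {
        std::getline(s, field, ',');
    }

    Row row;
    // Read the three input features X1, X2, X3.
    for (int i = 0; i < 3; ++i) {
        std::getline(s, field, ',');
        row.x.push_back(std::stof(field));
    }

    // Read the regression target Y.
    std::getline(s, field, ',');
    row.y = std::stof(field);
    return row;
}
```

Inside the `while (getline(in,line))` loop, a call such as `Row row = ParseRow(line);` would then yield the three features in `row.x` and the target in `row.y`.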

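You need to change the get method of your dataset. x is already a std::vector<float>, which torch::tensor accepts directly; the extra braces in torch::tensor({x}) wrap it in another initializer list, which appears to be what triggers the fill_tensor assertion.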

torch::data::Example<> get(size_t index) override {
   std::vector<float> x = std::get<0>(csv_[index]);
   float y = std::get<1>(csv_[index]);

   torch::Tensor x_input = torch::tensor( x ); // - removed braces here
   torch::Tensor y_label = torch::tensor({ y });

   return { x_input, y_label };
};

Thank you! It works!!