(libtorch) How to use torch::data::datasets for custom dataset?

(Heng Ding) #1

Hi everyone,

Following the MNIST example at https://github.com/goldsborough/examples/blob/cpp/cpp/mnist/mnist.cpp, I am trying to write my own training program using libtorch.

However, I can’t find any documentation on how to load my own dataset.
It seems that the C++ API is similar to the Python API:

First, use torch::data::datasets to create a dataset object.
Second, use torch::data::make_data_loader to create a pointer to a data loader.

But I don’t know how to define a custom dataset using torch::data::datasets. Can anyone help me?

(Martin Huber) #2

Hi @Hengd,

Well, actually it is super easy. Just as in the MNIST example, you can implement a torch::data::datasets::Dataset<Self, SingleExample>. To do so, you need to override the get(size_t index) method from Dataset, and also size(), which the base class declares as pure virtual. What you need to do is get your data from somewhere and convert it into a Tensor, but how you do that is up to you.

#include <torch/torch.h>

#include <string>

// You can for example just read your data and directly store it as a tensor.
torch::Tensor read_data(const std::string& loc)
{
    torch::Tensor tensor = ...

    // Here you need to get your data.

    return tensor;
}

class MyDataset : public torch::data::Dataset<MyDataset>
{
    private:
        torch::Tensor states_, labels_;

    public:
        explicit MyDataset(const std::string& loc_states, const std::string& loc_labels)
            : states_(read_data(loc_states)),
              labels_(read_data(loc_labels)) {}

        torch::data::Example<> get(size_t index) override;

        // size() is pure virtual in the base class, so it must be overridden as well.
        torch::optional<size_t> size() const override;
};

torch::data::Example<> MyDataset::get(size_t index)
{
    // You may for example also read in a .csv file that stores locations
    // to your data and then read in the data at this step. Be creative.
    return {states_[index], labels_[index]};
}

torch::optional<size_t> MyDataset::size() const
{
    // One sample per row of the states tensor.
    return states_.size(0);
}

Then, to generate a data loader from it, just do:

// Generate your data set. At this point you can add transforms to your data set, e.g. stack the
// examples of each batch into a single tensor.
auto data_set = MyDataset(loc_states, loc_labels).map(torch::data::transforms::Stack<>());

// Generate a data loader. batch_size is a placeholder for the batch size you want to use.
auto data_loader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
    std::move(data_set),
    batch_size);

// In a for loop you can now use your data. make_data_loader returns a std::unique_ptr,
// so dereference it before iterating.
for (auto& batch : *data_loader) {
    auto data = batch.data;
    auto labels = batch.target;
    // do your usual stuff
}

Hopefully this helps, although I don’t know the kind of data you are trying to read in.