(libtorch) How to use torch::data::datasets for custom dataset?

(Heng Ding) #1

Hi everyone,

Following the MNIST example at https://github.com/goldsborough/examples/blob/cpp/cpp/mnist/mnist.cpp, I am trying to write my own training program using libtorch.

However, I can't find any documentation about how to load my own dataset.
It seems that the C++ API is similar to the Python API.

First, use torch::data::datasets to create a dataset object.
Second, use torch::data::make_data_loader to create a pointer to a data loader.

But I don't know how to define a custom dataset using torch::data::datasets. Can anyone help me?

(Martin Huber) #2

Hi @Hengd,

Well, actually it is quite easy. Just as the MNIST example does, you implement a subclass of torch::data::datasets::Dataset<Self, SingleExample>. You need to override its get(size_t index) method, which returns a single example, and its size() method, which returns the number of examples. How you get your data and convert it into a Tensor is up to you.

#include <torch/torch.h>

#include <string>

// You can for example just read your data and directly store it as a tensor.
torch::Tensor read_data(const std::string& loc)
{
    torch::Tensor tensor = ...;

    // Here you need to get your data.

    return tensor;
}

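class MyDataset : public torch::data::Dataset<MyDataset>
{
    private:
        torch::Tensor states_, labels_;

    public:
        explicit MyDataset(const std::string& loc_states, const std::string& loc_labels)
            : states_(read_data(loc_states)),
              labels_(read_data(loc_labels)) {}

        torch::data::Example<> get(size_t index) override;

        // size() is pure virtual in the base class, so it must be overridden too.
        torch::optional<size_t> size() const override
        {
            return static_cast<size_t>(states_.size(0));
        }
};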

torch::data::Example<> MyDataset::get(size_t index)
{
    // You may for example also read in a .csv file that stores locations
    // to your data and then read in the data at this step. Be creative.
    return {states_[index], labels_[index]};
}

Then, to generate a data loader from it, just do:

// Generate your data set. At this point you can add transforms to your data set, e.g. stack the
// examples of each batch into a single tensor.
auto data_set = MyDataset(loc_states, loc_labels).map(torch::data::transforms::Stack<>());

// Generate a data loader. batch_size is the number of examples per batch you choose.
auto data_loader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
    std::move(data_set),
    batch_size);

// In a for loop you can now use your data. make_data_loader returns a
// std::unique_ptr, so dereference it.
for (auto& batch : *data_loader) {
    auto data = batch.data;
    auto labels = batch.target;
    // do your usual stuff
}
Hopefully this helps, although I don’t know the kind of data you are trying to read in.


(dambo) #3

Were you able to come up with a working example that mimics the following PyTorch (Python) dataset for reading images?

from os.path import join

import torch.utils.data
from PIL import Image


class GenericDataset(torch.utils.data.Dataset):
  # labels: a pandas DataFrame (accessed via .iloc); column 0 is the file name, column 2 the category id
  def __init__(self, labels, root_dir, subset=False, transform=None):
    self.labels = labels
    self.root_dir = root_dir
    self.transform = transform

  def __len__(self):
    return len(self.labels)

  def __getitem__(self, idx):
    img_name = self.labels.iloc[idx, 0]  # file name
    fullname = join(self.root_dir, img_name)
    image = Image.open(fullname).convert('RGB')
    labels = self.labels.iloc[idx, 2]  # category_id
    if self.transform:
      image = self.transform(image)
    return image, labels


(Martin Huber) #4

Hi @dambo,

Yes, the example above mimics the PyTorch (Python) version of a dataset.

I will implement an example which clarifies it further and post the link here.

(Martin Huber) #5

I have now implemented a small classifier, trained on a custom dataset, that distinguishes apples from bananas. You can find it here