Hey everyone,
I am running into an "undefined reference" error when creating a custom dataset class with libtorch. May I ask for a code review to help clarify some things?
Here is my `data.hpp`:
#pragma once
#include <torch/torch.h>
namespace rock {
namespace data {
namespace datasets {

/// A dataset of randomly generated images paired with constant labels,
/// structured after libtorch's MNIST dataset class.
class RandomDataset : public torch::data::Dataset<RandomDataset> {
 public:
  /// The split (train or test) requested when constructing the dataset.
  enum class Mode { kTrain, kTest };

  /// Constructs the dataset, generating all images and labels up front.
  /// `mode` is the requested split; it defaults to the training split.
  explicit RandomDataset(Mode mode = Mode::kTrain);

  /// Returns the `Example` (image and label) at the given `index`.
  torch::data::Example<> get(size_t index) override;

  /// Returns the number of samples in the dataset.
  torch::optional<size_t> size() const override;

  /// Returns true if this is the training split. The implementation decides
  /// this by comparing the number of stored samples with the training-set size.
  bool is_train() const noexcept;

  /// Returns all images stacked into a single tensor.
  const torch::Tensor& images() const;

  /// Returns all labels stacked into a single tensor.
  const torch::Tensor& labels() const;

 private:
  torch::Tensor images_;
  torch::Tensor labels_;
};

}  // namespace datasets
}  // namespace data
}  // namespace rock
And my `data.cpp`:
#include <torch/torch.h>
#include "data.hpp"
namespace rock {
namespace data {
namespace datasets {
namespace {

// Number of samples generated for each split (mirroring the MNIST split
// sizes). The two values must differ: is_train() tells the splits apart by
// sample count.
constexpr int64_t kTrainSize = 60000;
constexpr int64_t kTestSize = 10000;
static_assert(kTrainSize != kTestSize,
              "is_train() distinguishes splits by their sample count");

constexpr int64_t kNumChannels = 3;
constexpr int64_t kImageRows = 256;
constexpr int64_t kImageColumns = 256;

// Returns how many samples to generate for the requested split.
constexpr int64_t num_samples_for(RandomDataset::Mode mode) {
  return mode == RandomDataset::Mode::kTrain ? kTrainSize : kTestSize;
}

// Creates `num_samples` random images with float32 pixel values in [0, 1].
//
// Shape: {num_samples, kImageRows, kImageColumns, kNumChannels}, i.e.
// channels-last. NOTE(review): torch::nn::Conv2d expects channels-first
// {N, C, H, W}; permute before feeding a conv net if that is the consumer.
//
// NOTE(review): the whole split is materialized eagerly. At 256x256x3 float32
// per image, the 60000 training images need roughly 47 GB of memory, which
// will fail on most machines. Consider smaller sizes or generating samples
// lazily inside get().
torch::Tensor load_images(int64_t num_samples) {
  // randint's upper bound is exclusive, so 256 covers the full 8-bit range
  // [0, 255]. Generating directly as float32 avoids the second full-size copy
  // that a separate .to(torch::kFloat32) conversion would allocate.
  torch::Tensor images = torch::randint(
      /*low=*/0, /*high=*/256,
      {num_samples, kImageRows, kImageColumns, kNumChannels},
      torch::kFloat32);
  return images.div_(255);
}

// Creates `num_samples` labels, all equal to 1, directly as int64 (the dtype
// used for class-index targets) without an intermediate int32 tensor.
torch::Tensor load_labels(int64_t num_samples) {
  return torch::ones({num_samples}, torch::kInt64);
}

}  // namespace

// Generates the images and labels for the split selected by `mode`.
RandomDataset::RandomDataset(Mode mode)
    : images_(load_images(num_samples_for(mode))),
      labels_(load_labels(num_samples_for(mode))) {}

// Returns the image and label stored at `index`.
torch::data::Example<> RandomDataset::get(size_t index) {
  const auto i = static_cast<int64_t>(index);
  return {images_[i], labels_[i]};
}

// The dataset size is the length of the leading (sample) dimension.
torch::optional<size_t> RandomDataset::size() const {
  return static_cast<size_t>(images_.size(0));
}

// Only the training split holds kTrainSize samples (see num_samples_for).
bool RandomDataset::is_train() const noexcept {
  return images_.size(0) == kTrainSize;
}

const torch::Tensor& RandomDataset::images() const {
  return images_;
}

const torch::Tensor& RandomDataset::labels() const {
  return labels_;
}

}  // namespace datasets
}  // namespace data
}  // namespace rock
I saw that @mhubii linked a good reference implementation for a custom dataset in the thread "(libtorch) How to use torch::data::datasets for custom dataset?", but I am not quite sure where I am going wrong compared to that reference.
I structured my code to be similar to libtorch's `mnist.cpp`.
I appreciate your help and comments!