Flatten training data in a batch

Hi,
I am sorry to bother you, but I hope you can help me.
Yesterday I started working with libtorch (I want to speed up a model a bit) and began with a simple toy model. Now I would like to get each MNIST image as a 1D tensor instead of a 2D one. Unfortunately, I do not understand how libtorch transforms work (and I didn't find any good documentation), so I have not managed to flatten the training data.
I am currently using this:

```cpp
auto data_loader = torch::data::make_data_loader(
            torch::data::datasets::MNIST(R"(...\MNIST\raw)").map(torch::data::transforms::Stack<>()), 64);
```

to get the training data in batches of shape {64, 1, 28, 28}, but I want {64, 784} instead. How can I modify the code above to get that?

Best regards
NPC

Hi,
just in case anybody else is a little lost, this is what worked for me:

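```cpp
#include <torch/torch.h>

// Flattens a single MNIST image of shape {1, 28, 28} into {784}.
// TensorLambda applies it to each example's data (not its target) before batching.
at::Tensor flatter(at::Tensor tensor_) {
    return tensor_.flatten();
}

// PATH is a placeholder for the directory containing the raw MNIST files.
// Each batch's data now has shape {64, 784}.
auto data_loader = torch::data::make_data_loader(
            torch::data::datasets::MNIST(PATH).map(
                    torch::data::transforms::TensorLambda<>(flatter)).map(
                    torch::data::transforms::Stack<>()), 64);
```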