Could you explain your data shape a bit?
If you are stacking these `[64, 256, 16, 16]` features, you would end up with a tensor of shape `[N, 64, 256, 16, 16]`.
Would dim0 in this case refer to the batch size?
If so, you could call `whole_list = torch.stack(whole_list)` and pass the result to a `TensorDataset` as shown in my example.
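A minimal sketch of this approach, assuming a hypothetical list of feature tensors and placeholder integer targets (your actual targets will differ):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Hypothetical list of N feature tensors, each of shape [64, 256, 16, 16]
whole_list = [torch.randn(64, 256, 16, 16) for _ in range(4)]

# torch.stack adds a new leading dimension -> [N, 64, 256, 16, 16]
features = torch.stack(whole_list)
print(features.shape)  # torch.Size([4, 64, 256, 16, 16])

# Placeholder targets, one per sample, just for illustration
targets = torch.randint(0, 10, (features.size(0),))

dataset = TensorDataset(features, targets)
loader = DataLoader(dataset, batch_size=2)

# Each batch indexes dim0 of the stacked tensor
x, y = next(iter(loader))
print(x.shape)  # torch.Size([2, 64, 256, 16, 16])
```

Note that `TensorDataset` indexes all passed tensors along dim0, so dim0 must be the sample (batch) dimension for every tensor you pass in.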