A more general approach could use torch.split and a better .scatter_ call:
nb_cat = 3      # number of images to concatenate per sample
nb_classes = 10 # set to the number of classes in your dataset

for i, data in enumerate(trainloader, 0):
    inputs, labels = data
    batch_size = inputs.size(0)
    split_size = batch_size // nb_cat
    # Stack nb_cat images along the height dim:
    # [batch_size, C, H, W] -> [split_size, C, nb_cat*H, W]
    # Composite sample j contains images j, j+split_size, j+2*split_size, ...
    inputs = torch.cat(inputs.split(split_size), 2)
    # Group the labels the same way the images were grouped,
    # so row j holds the labels of composite sample j: [split_size, nb_cat]
    labels = torch.stack(labels.split(split_size), dim=1)
    # Multi-hot targets: set a 1 at each of the nb_cat label positions
    one_hot_labels = torch.zeros(split_size, nb_classes, dtype=torch.long).scatter_(1, labels, 1)
    one_hot_labels = one_hot_labels.float()
EDIT: This always assumes the batch size is divisible by nb_cat
without a remainder.
I haven’t tested edge cases!
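As a sanity check, here is a self-contained sketch on a tiny fake batch (the shapes, nb_classes=10, and the label values are just assumptions for illustration):

```python
import torch

nb_cat = 3      # images per composite sample
nb_classes = 10 # assumed number of classes

# Fake batch: 6 single-channel 4x4 images with distinct labels
inputs = torch.randn(6, 1, 4, 4)
labels = torch.tensor([0, 1, 2, 3, 4, 5])

batch_size = inputs.size(0)
split_size = batch_size // nb_cat  # -> 2 composite samples

# Stack nb_cat images along the height dimension
inputs = torch.cat(inputs.split(split_size), 2)
print(inputs.shape)  # torch.Size([2, 1, 12, 4])

# Composite sample j is built from images j, j+split_size, j+2*split_size,
# so group the labels the same way
labels = torch.stack(labels.split(split_size), dim=1)
print(labels)  # tensor([[0, 2, 4], [1, 3, 5]])

# Multi-hot targets: one 1 per constituent image's class
one_hot = torch.zeros(split_size, nb_classes, dtype=torch.long).scatter_(1, labels, 1)
print(one_hot)
```

Each row of `one_hot` ends up with exactly nb_cat ones, at the class indices of the images stacked into that composite sample.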