Concatenating images

A more general approach could use torch.split and a better .scatter_ call:

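```python
import torch

nb_cat = 3        # Number of images to concatenate
nb_classes = 10   # Number of classes (placeholder; set to match your dataset)

for i, data in enumerate(trainloader, 0):
    inputs, labels = data                   # inputs: [B, C, H, W], labels: [B]
    batch_size = inputs.size(0)
    split_size = batch_size // nb_cat
    # Sample j is built from images j, j + split_size, j + 2 * split_size, ...
    inputs = torch.cat(inputs.split(split_size), 2)    # [B // nb_cat, C, nb_cat * H, W]

    # Group the labels the same way, so row j matches concatenated image j
    labels = torch.stack(labels.split(split_size), 1)  # [B // nb_cat, nb_cat]
    # Multi-hot float target: 1.0 for each class present in the concatenated image
    one_hot_labels = torch.zeros(inputs.size(0), nb_classes).scatter_(1, labels, 1.0)
```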
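To check that the concatenated images and their labels really line up, here is a small self-contained sketch that wraps the loop body in a function. The name `concat_batch` is hypothetical, and it assumes `inputs` is `[B, C, H, W]`, `labels` is an int64 tensor of shape `[B]`, and `B` is divisible by `nb_cat`. Each synthetic image is filled with its own label value, so you can see which source images ended up in which sample.

```python
import torch


def concat_batch(inputs: torch.Tensor, labels: torch.Tensor, nb_cat: int, nb_classes: int):
    """Hypothetical helper: concatenate groups of nb_cat images along the height dim.

    Assumes inputs is [B, C, H, W], labels is int64 [B], and B % nb_cat == 0.
    """
    split_size = inputs.size(0) // nb_cat
    # Sample j of the result is made of images j, j + split_size, j + 2 * split_size, ...
    cat_inputs = torch.cat(inputs.split(split_size), 2)
    # Group the labels the same way: row j = labels[j], labels[j + split_size], ...
    cat_labels = torch.stack(labels.split(split_size), 1)
    # Multi-hot target: 1.0 for every class present in the concatenated image
    targets = torch.zeros(cat_inputs.size(0), nb_classes).scatter_(1, cat_labels, 1.0)
    return cat_inputs, cat_labels, targets


# Synthetic batch of 6 images of size 1x2x2; every pixel of image k holds the value k
labels = torch.arange(6)
inputs = labels.float().view(6, 1, 1, 1).expand(6, 1, 2, 2).clone()

cat_inputs, cat_labels, targets = concat_batch(inputs, labels, nb_cat=3, nb_classes=6)
print(cat_inputs.shape)                   # torch.Size([2, 1, 6, 2])
print(cat_labels.tolist())                # [[0, 2, 4], [1, 3, 5]]
print(cat_inputs[0, 0, :, 0].tolist())    # [0.0, 0.0, 2.0, 2.0, 4.0, 4.0]
```

The first concatenated sample is built from images 0, 2 and 4, and its label row is `[0, 2, 4]`, so images and targets stay aligned. If two of the concatenated images share a class, `scatter_` simply writes 1.0 twice into the same slot.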

EDIT: This always assumes that the batch size is divisible by nb_cat without a remainder.
I haven’t tested edge cases!
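One possible way to handle a batch whose size is not divisible by nb_cat (for example the last, shorter batch of an epoch) is to drop the trailing samples before concatenating. This is only a sketch under the assumption that losing up to `nb_cat - 1` samples per batch is acceptable; `trim_to_multiple` is a hypothetical helper. Alternatively, picking a `DataLoader` `batch_size` that is a multiple of nb_cat together with `drop_last=True` avoids the short final batch altogether.

```python
import torch


def trim_to_multiple(inputs: torch.Tensor, labels: torch.Tensor, nb_cat: int):
    """Hypothetical helper: drop trailing samples so the batch size is divisible by nb_cat.

    Discards up to nb_cat - 1 samples; returns empty tensors if fewer than nb_cat remain.
    """
    usable = (inputs.size(0) // nb_cat) * nb_cat
    return inputs[:usable], labels[:usable]


# A last batch of 7 samples with nb_cat = 3 keeps only the first 6
inputs = torch.randn(7, 1, 4, 4)
labels = torch.arange(7)
inputs, labels = trim_to_multiple(inputs, labels, nb_cat=3)
print(inputs.shape, labels.tolist())  # torch.Size([6, 1, 4, 4]) [0, 1, 2, 3, 4, 5]
```

A batch with fewer than nb_cat samples comes back empty, so you would want to skip it in the loop, since `split` would then be called with a split size of 0.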