Your image tensors are missing the channel dimension.
I would assume `Image.open` returns an image with a channel dimension, even for a grayscale image.
However, as a workaround you could add the missing channel dimension via `unsqueeze` before returning the tensor:
```python
# insert a size-1 channel dimension at dim 1
x = x.unsqueeze(1)
return x, y
```
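Which dimension to unsqueeze depends on the shape of `x` at that point. Here is a small sketch. It assumes the grayscale image was converted with something like `np.array(Image.open(path))`, which gives a 2D `[H, W]` array for an `"L"` mode image. A zero array stands in for that image, and the sizes and the `Conv2d` layer are hypothetical. For a single sample the channel goes in front (`unsqueeze(0)`). For a tensor that already holds a batch `[N, H, W]`, it goes at dim 1 (`unsqueeze(1)`).

```python
import numpy as np
import torch

# Hypothetical stand-in for np.array(Image.open(path)) on a grayscale ("L") image:
# a 2D array of shape [H, W] with no channel dimension.
H, W = 28, 28
gray = np.zeros((H, W), dtype=np.uint8)

# A single sample, e.g. inside a Dataset's __getitem__: shape [H, W]
x = torch.from_numpy(gray).float()

# Add the channel dimension in front: [H, W] -> [1, H, W]
x_single = x.unsqueeze(0)

# If x already holds a batch of grayscale images, [N, H, W],
# the channel dimension goes at dim 1: [N, H, W] -> [N, 1, H, W]
batch = torch.stack([x, x, x, x])
x_batch = batch.unsqueeze(1)

# A hypothetical conv layer expecting one input channel accepts [N, 1, H, W]
conv = torch.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
out = conv(x_batch)

print(x_single.shape)  # torch.Size([1, 28, 28])
print(x_batch.shape)   # torch.Size([4, 1, 28, 28])
print(out.shape)       # torch.Size([4, 6, 24, 24])
```

If `__getitem__` returns `[1, H, W]` samples, the default `DataLoader` collation stacks them into the `[N, 1, H, W]` shape that `nn.Conv2d` expects.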