Help :( Autoencoder predicts shapes 'ok' but lacks colours

A simple autoencoder is learning to reconstruct images; currently this is the output:

Original:
[image]

Predicted:
[image]

This is not great, but I'd assume it will improve.

I’m somewhat concerned that there are no colours in the prediction.

The input to the network (CIFAR10, i.e. 3x32x32 images) is transformed with: transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

Those are sent to the network.

Hence, when plotting the images, I un-normalize with X / 2 + 0.5.

The same is done for the network output.
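For context, the pipeline is roughly this (the DataLoader arguments and dataset path here are just examples, not my exact setup):

    import torch
    import torchvision
    import torchvision.transforms as transforms

    # ToTensor gives [0, 1]; Normalize((0.5,)*3, (0.5,)*3) maps each channel to [-1, 1]
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    trainset = torchvision.datasets.CIFAR10(root="./data", train=True,
                                            download=True, transform=transform)
    data = torch.utils.data.DataLoader(trainset, batch_size=16, shuffle=True)

    def unnormalize(x):
        # invert the Normalize above: map [-1, 1] back to [0, 1] for plotting
        return x / 2 + 0.5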

The snippet looks like this:

    with torch.no_grad():
        net.eval()
        imgs, labels = next(iter(data))   # grab one batch from the DataLoader
        r = net(imgs)                     # autoencoder reconstructions

        # undo Normalize((0.5,)*3, (0.5,)*3): map [-1, 1] back to [0, 1] for display
        imgs = imgs / 2 + 0.5
        r = r / 2 + 0.5

        img_grid = torchvision.utils.make_grid(imgs)
        net_img_grid = torchvision.utils.make_grid(r)
        # print(imgs[0], r[0])
        writer.add_image("images1", img_grid)        # originals
        writer.add_image("images2", net_img_grid)    # reconstructions
    writer.close()

The un-normalization applied to the original coloured images is exactly the same one applied to the outputs of the network.

What I see in the printed tensors is:

# image (note it contains both positive and negative values)
tensor([[[ 0.2392,  0.2471,  0.2941,  ...,  0.0745, -0.0118, -0.0902],
         [ 0.1922,  0.1843,  0.2471,  ...,  0.0667, -0.0196, -0.0667],
         [ 0.1843,  0.1843,  0.2392,  ...,  0.0902,  0.0196, -0.0588],
]) 

# prediction: all negative values.
tensor([[[-0.0537, -0.1358, -0.1682,  ..., -0.1811, -0.1710, -0.0690],
         [-0.1587, -0.2982, -0.3134,  ..., -0.2757, -0.2924, -0.1612],
         [-0.2067, -0.3414, -0.3217,  ..., -0.2506, -0.2766, -0.1937],
         ...,])

It’s pretty clear that the network is not learning the sign correctly.

I thought this could be due to MSELoss, but I doubt it, since not predicting the right sign gives a much larger error (the distance between -0.2 and -0.3 is much smaller than between -0.2 and +0.3).
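Just to sanity-check that reasoning, a quick made-up example with MSELoss:

    import torch
    import torch.nn.functional as F

    target = torch.tensor([-0.2])
    # same sign, close value -> small loss
    print(F.mse_loss(torch.tensor([-0.3]), target))  # tensor(0.0100)
    # wrong sign -> much larger loss
    print(F.mse_loss(torch.tensor([0.3]), target))   # tensor(0.2500)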

The only thing I can think of as being odd, or that I'm unsure about, is using Dropout2d as opposed to plain Dropout; I'm not sure which is best, or whether both should be used. See the illustration below.
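For clarity, this is the difference I mean (just an illustration, not my actual model): nn.Dropout zeroes individual elements, while nn.Dropout2d zeroes whole feature maps/channels.

    import torch
    import torch.nn as nn

    x = torch.ones(1, 3, 4, 4)  # (batch, channels, H, W)

    elem_drop = nn.Dropout(p=0.5)    # zeroes individual elements independently
    chan_drop = nn.Dropout2d(p=0.5)  # zeroes entire channels at once

    # modules are in training mode by default, so dropout is active here
    print(elem_drop(x)[0, 0])  # scattered zeros within a channel
    print(chan_drop(x)[0])     # some channels all zero, others rescaled but intact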

I’ve read the docs and checked some related posts, but to no avail.

Any ideas to try out?

Hi, could you share the architecture for your model/encoder/decoder? I would think that PyTorch would stop you if you tried to feed a three-channel image into a one-channel input, but perhaps that is your issue and it’s interpreting the image as grayscale.


I’m finally getting to understand the cause. It’s neither the network nor the normalisation of the data; the key seems to be the combination of optimizer and learning rate.

Adam gets to coloured images with a small learning rate; SGD doesn’t, and instead collapses to black-and-white outputs.

I wonder if SGD would need a much smaller learning rate.
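What I’m comparing is roughly this (the learning rates here are placeholders, not the exact values I used):

    import torch.optim as optim

    # Adam with a small learning rate reached coloured reconstructions for me
    # (the value below is just an example)
    optimizer = optim.Adam(net.parameters(), lr=1e-3)

    # SGD at a comparable rate collapsed to black-and-white outputs;
    # possibly it just needs a much smaller lr (untested guess)
    # optimizer = optim.SGD(net.parameters(), lr=1e-3, momentum=0.9)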

With just a few epochs it looks like this, and now the improvement mostly depends on the latent vector size and the number of layers: