Create a dataset from a GAN

Hi everyone,

I am currently playing around with GANs and built a simple one for CIFAR10.
For further use, e.g. as an evaluation set for a CNN, I wanted to save some generated images, but I couldn't figure out how to create a dataset from them and save it for later loading into a DataLoader. (Edit: I already found the answer myself :man_facepalming:)

For simplicity, let's say I know in advance which class (0-9) is generated, and I tell my generator to generate 100 airplane images that I want to save.

So in my example, the generator would take the random noise and generate the specified number of images. But at that point I am stuck. I thought about something like the code below, but it didn't seem to work, as I got a size mismatch error. (Edit: derp, see below :man_facepalming:)

gen_imgs = generator(torch.randn(100, 100, 1, 1))  # Variable() is no longer needed since PyTorch 0.4
gen_data = data_utils.TensorDataset(gen_imgs, torch.from_numpy(np.array([0]*100)))
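For reference, here is a self-contained sketch of the same idea. The real generator is not shown in the thread, so `ToyGenerator` and `generate_class_dataset` are hypothetical stand-ins (a DCGAN-style stack mapping (N, 100, 1, 1) noise to (N, 3, 64, 64) images). Two details are suggestions rather than the poster's code: generating under `torch.no_grad()` so the stored images do not keep an autograd graph alive, and building the labels with `torch.full` instead of going through NumPy. Label 0 is "airplane" in torchvision's CIFAR10 class ordering.

```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset

# Hypothetical stand-in for the real generator: maps noise of shape
# (N, 100, 1, 1) to images of shape (N, 3, 64, 64), DCGAN-style.
class ToyGenerator(nn.Module):
    def __init__(self, nz=100):
        super().__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(nz, 64, 4, 1, 0),  # 1x1 -> 4x4
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),  # 4x4 -> 8x8
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 4, 2, 1),  # 8x8 -> 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(16, 8, 4, 2, 1),   # 16x16 -> 32x32
            nn.ReLU(),
            nn.ConvTranspose2d(8, 3, 4, 2, 1),    # 32x32 -> 64x64
            nn.Tanh(),
        )

    def forward(self, z):
        return self.net(z)


def generate_class_dataset(generator, n_images, class_idx, nz=100):
    """Hypothetical helper: generate n_images of one known class."""
    generator.eval()
    with torch.no_grad():  # no autograd graph is stored with the images
        z = torch.randn(n_images, nz, 1, 1)
        imgs = generator(z)
    # One label per image, so the first dimensions match.
    labels = torch.full((n_images,), class_idx, dtype=torch.long)
    return TensorDataset(imgs, labels)


AIRPLANE = 0  # index of "airplane" in torchvision's CIFAR10 class list
gen_data = generate_class_dataset(ToyGenerator(), 100, AIRPLANE)
x, y = gen_data[0]
print(len(gen_data), x.shape, y.item())  # 100 torch.Size([3, 64, 64]) 0
```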

I would be happy if someone could fill me in on how to realise something like that.


I am not sure I completely understood your question, because you more or less wrote the solution yourself. I don't know the size of your gen_imgs, but for the sake of simplicity I'll assume the following:

import numpy as np
import torch
import torch.utils.data as data_utils

# Assume a batch dimension of 128, 3 channels and 32 pixel width and height
gen_imgs = torch.randn(128, 3, 32, 32)
# Your Y has to match the batch dimension
dataset = data_utils.TensorDataset(gen_imgs, torch.from_numpy(np.array([0]*128)))
dataloader = data_utils.DataLoader(dataset, batch_size=8)

for x, y in dataloader:
    # do something with the batch, e.g. x: [8, 3, 32, 32], y: [8]
    pass

Or do you want to know how to write the output of your generator to disk?
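If the goal is to keep the generated images on disk for later, one option is to save the raw tensors with `torch.save` and wrap them in a `TensorDataset` again after `torch.load`. This is a minimal sketch, assuming the images fit in memory; the file name `gen_airplanes.pt` and the dict keys are hypothetical.

```python
import os
import tempfile

import torch
from torch.utils.data import DataLoader, TensorDataset

# Placeholder images/labels; in practice these come from the generator.
gen_imgs = torch.randn(128, 3, 32, 32)
labels = torch.zeros(128, dtype=torch.long)

# Save both tensors together in one file (hypothetical file name).
path = os.path.join(tempfile.gettempdir(), "gen_airplanes.pt")
torch.save({"images": gen_imgs, "labels": labels}, path)

# Later, e.g. in the evaluation script: load and wrap in a Dataset again.
saved = torch.load(path)
dataset = TensorDataset(saved["images"], saved["labels"])
dataloader = DataLoader(dataset, batch_size=8, shuffle=False)

for x, y in dataloader:
    print(x.shape, y.shape)  # torch.Size([8, 3, 32, 32]) torch.Size([8])
    break
```

If image files are preferred (e.g. to look at them), `torchvision.utils.save_image` can write images as PNGs, but that converts them to 8-bit, so the tensors will not round-trip exactly the way they do with `torch.save`.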

You are totally right, and it was 100% my fault for not reading the error message properly. I thought I was getting a completely different tensor, but it was actually the right one. The only problem was that I had resized the CIFAR images to 64x64 for the GAN, which resulted in a [100 x 16384] vector, while my CNN was expecting 4096 (as for 32x32 inputs).
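The 16384 vs 4096 mismatch is exactly a factor of 4, which is what doubling both height and width gives once the feature maps are flattened. The poster's CNN is not shown, so the feature extractor below is a hypothetical one, chosen only so that a 32x32 input flattens to 4096 values; it illustrates why a first `Linear(4096, ...)` layer would reject 64x64 images.

```python
import torch
import torch.nn as nn

# Hypothetical feature extractor; the poster's actual CNN is not shown.
# The conv keeps the spatial size (padding=1), max pooling halves it.
features = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),  # (N, 16, H/2, W/2) -> (N, 16 * H/2 * W/2)
)

for side in (32, 64):
    x = torch.randn(100, 3, side, side)
    print(side, features(x).shape)
# 32 torch.Size([100, 4096])   -> 16 * 16 * 16
# 64 torch.Size([100, 16384])  -> 16 * 32 * 32
```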
So after a quick change to TensorDataset(F.interpolate(gen_imgs, 32), torch.from_numpy(np.array([0] * 100))), which resizes the generated images back to 32x32, everything worked.
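A quick sanity check of that fix, using random tensors as a stand-in for the 64x64 generator output. `F.interpolate(gen_imgs, 32)` passes `size=32`, which resizes both spatial dimensions, and its default mode is `"nearest"`. Using `mode="area"` instead is a suggestion, not what the poster did: for an exact 2x downscale it averages each 2x2 block, which is usually smoother than nearest-neighbour sampling.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset

gen_imgs = torch.randn(100, 3, 64, 64)  # stand-in for 64x64 generator output

# size=32 resizes height and width to 32; default mode is "nearest".
resized_nearest = F.interpolate(gen_imgs, size=32)
# "area" averages each 2x2 block for this exact 2x downscale.
resized_area = F.interpolate(gen_imgs, size=32, mode="area")

dataset = TensorDataset(resized_area, torch.zeros(100, dtype=torch.long))
print(resized_nearest.shape, dataset[0][0].shape)
# torch.Size([100, 3, 32, 32]) torch.Size([3, 32, 32])
```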
So thanks for pointing that out!