Hi,
I’m training a GAN to upscale low-resolution images to high-resolution images. For my validation set, I want to plot a randomly chosen image, but I always get the same one. The dataloader for my validation set has a batch size of one.
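```python
import numpy as np
import torch
from torch.autograd import Variable  # deprecated in recent PyTorch; plain tensors behave the same

# Tensor, cfg and val_loader are defined elsewhere in my script (not shown)
i_val = 0
for _, imgs in enumerate(val_loader):
    i_val += 1
    imgs_lr = Variable(imgs["lr"].type(Tensor))
    imgs_hr = Variable(imgs["hr"].type(Tensor))
    # VISUALIZING when the for loop is done (only true on the last batch)
    if i_val % len(val_loader) == 0:
        with torch.no_grad():
            # pick a random image inside the current batch
            index = torch.tensor([np.random.randint(imgs_lr.shape[0], size=1)[0]]).to(cfg.device)
            lr_i = torch.index_select(imgs_lr, 0, index, out=None)
```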
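A likely explanation: with a batch size of one, `imgs_lr.shape[0]` is 1, so `np.random.randint(1, size=1)[0]` can only return 0. The `if` also only fires on the last batch, so unless the validation loader shuffles (`DataLoader` uses `shuffle=False` by default), the same last sample is plotted every time. One way around this, sketched below under the assumption that the loader supports `len()` and iteration (as `torch.utils.data.DataLoader` does), is to draw the random batch index once before the loop and keep that batch. `pick_random_batch` is a hypothetical helper name, not a PyTorch API.

```python
import random


def pick_random_batch(loader, rng=None):
    """Iterate over every batch of `loader` and return one chosen uniformly at random.

    Hypothetical helper: `loader` can be anything with len() and iteration,
    e.g. a torch.utils.data.DataLoader.
    """
    if rng is None:
        rng = random.Random()
    target = rng.randrange(len(loader))  # drawn once per validation pass
    chosen = None
    for i, batch in enumerate(loader):
        # ... compute validation losses on `batch` here as usual ...
        if i == target:
            chosen = batch  # keep this batch for plotting after the loop
    return chosen
```

With a batch size of one, `chosen["lr"]` and `chosen["hr"]` would then be the single image pair to plot. Setting `shuffle=True` on the validation loader is a simpler alternative if a fixed validation order is not needed.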