I don’t have too much background in machine learning, but I am currently working on upscaling electron microscopy images using a GAN architecture. I have wrapped the models in nn.DataParallel and during training everything works as expected.
However, when applying the obtained weights during inference the results differ greatly from the periodically generated output during training, and I cannot figure out why this is. Inference using weights obtained when not calling DataParallel during training are as expected.
Below is the (relevant) code:
cuda = torch.cuda.is_available()
hr_shape = (opt.hr_height, opt.hr_width)
# Initialize generator and discriminator
generator = GeneratorResNet(factor=opt.factor)
discriminator = Discriminator(input_shape=(opt.channels, *hr_shape))
feature_extractor = FeatureExtractor()
# Set feature extractor to inference mode
feature_extractor.eval()
# Losses
criterion_GAN = torch.nn.MSELoss()
temp_err = torch.nn.MSELoss()
criterion_content = torch.nn.L1Loss()
if torch.cuda.device_count() > 1 and opt.parallel == 1:
print("Multiple GPU training")
temp_err = nn.DataParallel(temp_err)
generator = nn.DataParallel(generator)
discriminator = nn.DataParallel(discriminator)
feature_extractor = nn.DataParallel(feature_extractor)
criterion_GAN = nn.DataParallel(criterion_GAN)
criterion_content = nn.DataParallel(criterion_content)
if cuda:
temp_err = temp_err.cuda()
generator = generator.cuda()
discriminator = discriminator.cuda()
feature_extractor = feature_extractor.cuda()
criterion_GAN = criterion_GAN.cuda()
criterion_content = criterion_content.cuda()
if opt.epoch != 0:
inp = input("Saved models folder: ")
#inp2 = input("Saved images folder: ")
timestamp = inp.split("_s")[0]
saved_models = os.path.join(os.getcwd(), inp)
imgs_out = os.path.join(os.getcwd(), timestamp + "_images")
# Load pretrained models
generator.load_state_dict(torch.load(f"{saved_models}/generator_%d.pth" % opt.epoch))
discriminator.load_state_dict(torch.load(f"{saved_models}/discriminator_%d.pth" % opt.epoch))
# Read params when resuming training
arg_dict = pd.read_csv(f"{saved_models}/args.txt", header=None, index_col=0, sep=";").squeeze("columns").to_dict()
if int(arg_dict["n_epochs"]) > opt.n_epochs or opt.epoch >= opt.n_epochs:
opt.n_epochs = int(input("Total number of epochs to train:"))
# Make sure essential parameters are consistent when resuming training
opt.hr_height = int(arg_dict["hr_height"])
opt.hr_width = int(arg_dict["hr_width"])
opt.factor = int(arg_dict["factor"])
opt.channels = int(arg_dict["channels"])
opt.lr = float(arg_dict["lr"])
opt.b1 = float(arg_dict["b1"])
opt.b2 = float(arg_dict["b2"])
opt.decay_epoch = int(arg_dict["decay_epoch"])
else:
now = datetime.now()
imgs_out = f"{now.year}_{now.month}_{now.day}_{now.hour}_{now.minute}_images"
saved_models = f"{now.year}_{now.month}_{now.day}_{now.hour}_{now.minute}_saved_models"
os.makedirs(imgs_out, exist_ok=False)
os.makedirs(saved_models, exist_ok=False)
# Write arguments to file
with open(f"{saved_models}/args.txt", "w") as file:
for arg in vars(opt):
file.write(f"{arg};{getattr(opt, arg)}\n")
# addition = 0
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor
dataloader = DataLoader(
ImageDataset(datadir, hr_shape=hr_shape, factor=opt.factor),
# ImageDataset(os.path.join(os.getcwd(), "data/tiles/HR"), os.path.join(os.getcwd(), "data/tiles/LR")),
batch_size=opt.batch_size,
shuffle=True,
num_workers=opt.n_cpu,
)
# ----------
# Training
# ----------
if train:
for epoch in range(opt.epoch, opt.n_epochs):
for i, imgs in enumerate(dataloader):
# Configure model input
imgs_lr = Variable(imgs["lr"].type(Tensor))
# print("LR shape", imgs_lr.shape)
imgs_hr = Variable(imgs["hr"].type(Tensor))
# print("HR shape", imgs_hr.shape)
# Adversarial ground truths
if torch.cuda.device_count() > 1 and opt.parallel == 1:
valid = Variable(Tensor(np.ones((imgs_lr.size(0), *discriminator.module.output_shape))), requires_grad=False)
fake = Variable(Tensor(np.zeros((imgs_lr.size(0), *discriminator.module.output_shape))), requires_grad=False)
else:
valid = Variable(Tensor(np.ones((imgs_lr.size(0), *discriminator.output_shape))), requires_grad=False)
fake = Variable(Tensor(np.zeros((imgs_lr.size(0), *discriminator.output_shape))), requires_grad=False)
# ------------------
# Train Generators
# ------------------
optimizer_G.zero_grad()
# Generate a high resolution image from low resolution input
gen_hr = generator(imgs_lr)
# print("gen_hr shape:", gen_hr.shape)
# print("valid shape:", valid.shape)
# Adversarial loss
loss_GAN = criterion_GAN(discriminator(gen_hr), valid)
# Content loss
if opt.channels == 1: # Duplicate single channel to 3 channel tensor to match vgg19's expected input
fe_gen_hr = gen_hr.expand(-1, 3, -1, -1)
fe_imgs_hr = imgs_hr.expand(-1, 3, -1, -1)
else:
fe_gen_hr = gen_hr
fe_imgs_hr = imgs_hr
gen_features = feature_extractor(fe_gen_hr)
real_features = feature_extractor(fe_imgs_hr)
# loss_content = criterion_content(gen_features, real_features.detach())
loss_content = temp_err(gen_hr, imgs_hr)
# Total loss
loss_G = loss_content + 1e-3 * loss_GAN
if torch.cuda.device_count() > 1:
loss_G.sum().backward()
else:
loss_G.backward()
optimizer_G.step()
# ---------------------
# Train Discriminator
# ---------------------
optimizer_D.zero_grad()
# Loss of real and fake images
loss_real = criterion_GAN(discriminator(imgs_hr), valid)
loss_fake = criterion_GAN(discriminator(gen_hr.detach()), fake)
# Total loss
loss_D = (loss_real + loss_fake) / 2
if torch.cuda.device_count() > 1 and opt.parallel == 1:
loss_D.sum().backward()
else:
loss_D.backward()
optimizer_D.step()
# --------------
# Log Progress
# --------------
sys.stdout.write(
"[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]\n"
% (epoch, opt.n_epochs, i, len(dataloader), loss_D.mean().item(), loss_G.mean().item())
)
batches_done = epoch * len(dataloader) + i
if batches_done % opt.sample_interval == 0:
# Save image grid with upsampled inputs, outputs and ground truth
imgs_lr = nn.functional.interpolate(imgs_lr, scale_factor=opt.factor)
gen_hr = make_grid(gen_hr, nrow=1, normalize=True)
imgs_lr = make_grid(imgs_lr, nrow=1, normalize=True)
imgs_hr = make_grid(imgs_hr, nrow=1, normalize=True)
img_grid = torch.cat((imgs_lr, gen_hr, imgs_hr), -1)
save_image(img_grid, f"{imgs_out}/{epoch}_{i}.png", normalize=False)
if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
# Save model checkpoints
if torch.cuda.device_count() > 1 and opt.parallel == 1:
torch.save(generator.module.state_dict(), f"{saved_models}/generator_%d.pth" % epoch)
torch.save(discriminator.module.state_dict(), f"{saved_models}/discriminator_%d.pth" % epoch)
else:
torch.save(generator.state_dict(), f"{saved_models}/generator_%d.pth" % epoch)
torch.save(discriminator.state_dict(), f"{saved_models}/discriminator_%d.pth" % epoch)
I hope someone can point me in the right direction. Thanks!