Unexpected results when running inference after training using DataParallel

I don’t have much background in machine learning, but I am currently working on upscaling electron microscopy images with a GAN architecture. I have wrapped the models in nn.DataParallel for multi-GPU training, and during training everything works as expected.

However, when I apply the obtained weights at inference, the results differ greatly from the sample outputs that are periodically generated during training, and I cannot figure out why. Inference with weights obtained from training without DataParallel works as expected.
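
For context, this is roughly how I load the saved generator at inference (a simplified sketch; the checkpoint path and the low-resolution input tensor are placeholders):

    # Simplified inference sketch -- checkpoint path and lr_tensor are placeholders
    generator = GeneratorResNet(factor=opt.factor)
    generator.load_state_dict(torch.load("saved_models/generator_199.pth", map_location="cpu"))
    generator = generator.cuda().eval()

    with torch.no_grad():
        sr = generator(lr_tensor.cuda())  # lr_tensor: (1, C, H, W) low-resolution input
    save_image(sr, "upscaled.png", normalize=True)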

Below is the (relevant) code:

    cuda = torch.cuda.is_available()
    hr_shape = (opt.hr_height, opt.hr_width)

    # Initialize generator and discriminator
    generator = GeneratorResNet(factor=opt.factor)
    discriminator = Discriminator(input_shape=(opt.channels, *hr_shape))
    feature_extractor = FeatureExtractor()

    # Set feature extractor to inference mode
    feature_extractor.eval()

    # Losses
    criterion_GAN = torch.nn.MSELoss()
    temp_err = torch.nn.MSELoss()
    criterion_content = torch.nn.L1Loss()

    if torch.cuda.device_count() > 1 and opt.parallel == 1:
        print("Multiple GPU training")
        temp_err = nn.DataParallel(temp_err)
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
        feature_extractor = nn.DataParallel(feature_extractor)
        criterion_GAN = nn.DataParallel(criterion_GAN)
        criterion_content = nn.DataParallel(criterion_content)

    if cuda:
        temp_err = temp_err.cuda()
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        feature_extractor = feature_extractor.cuda()
        criterion_GAN = criterion_GAN.cuda()
        criterion_content = criterion_content.cuda()

    if opt.epoch != 0:
        inp = input("Saved models folder: ")
        #inp2 = input("Saved images folder: ")
        timestamp = inp.split("_s")[0]
        saved_models = os.path.join(os.getcwd(), inp)
        imgs_out = os.path.join(os.getcwd(), timestamp + "_images")
        # Load pretrained models
        generator.load_state_dict(torch.load(f"{saved_models}/generator_{opt.epoch}.pth"))
        discriminator.load_state_dict(torch.load(f"{saved_models}/discriminator_{opt.epoch}.pth"))

        # Read params when resuming training
        arg_dict = pd.read_csv(f"{saved_models}/args.txt", header=None, index_col=0, sep=";").squeeze("columns").to_dict()

        if int(arg_dict["n_epochs"]) > opt.n_epochs or opt.epoch >= opt.n_epochs:
            opt.n_epochs = int(input("Total number of epochs to train:"))

        # Make sure essential parameters are consistent when resuming training
        opt.hr_height = int(arg_dict["hr_height"])
        opt.hr_width = int(arg_dict["hr_width"])
        opt.factor = int(arg_dict["factor"])
        opt.channels = int(arg_dict["channels"])
        opt.lr = float(arg_dict["lr"])
        opt.b1 = float(arg_dict["b1"])
        opt.b2 = float(arg_dict["b2"])
        opt.decay_epoch = int(arg_dict["decay_epoch"])

    else:
        now = datetime.now()
        imgs_out = f"{now.year}_{now.month}_{now.day}_{now.hour}_{now.minute}_images"
        saved_models = f"{now.year}_{now.month}_{now.day}_{now.hour}_{now.minute}_saved_models"
        os.makedirs(imgs_out, exist_ok=False)
        os.makedirs(saved_models, exist_ok=False)

        # Write arguments to file
        with open(f"{saved_models}/args.txt", "w") as file:
            for arg in vars(opt):
                file.write(f"{arg};{getattr(opt, arg)}\n")
        # addition = 0

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

    Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

    dataloader = DataLoader(
        ImageDataset(datadir, hr_shape=hr_shape, factor=opt.factor),
        # ImageDataset(os.path.join(os.getcwd(), "data/tiles/HR"), os.path.join(os.getcwd(), "data/tiles/LR")),
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.n_cpu,
    )

    # ----------
    #  Training
    # ----------
    if train:
        for epoch in range(opt.epoch, opt.n_epochs):
            for i, imgs in enumerate(dataloader):
                # Configure model input
                imgs_lr = Variable(imgs["lr"].type(Tensor))
                # print("LR shape", imgs_lr.shape)
                imgs_hr = Variable(imgs["hr"].type(Tensor))
                # print("HR shape", imgs_hr.shape)
                # Adversarial ground truths
                if torch.cuda.device_count() > 1 and opt.parallel == 1:
                    valid = Variable(Tensor(np.ones((imgs_lr.size(0), *discriminator.module.output_shape))), requires_grad=False)
                    fake = Variable(Tensor(np.zeros((imgs_lr.size(0), *discriminator.module.output_shape))), requires_grad=False)
                else:
                    valid = Variable(Tensor(np.ones((imgs_lr.size(0), *discriminator.output_shape))), requires_grad=False)
                    fake = Variable(Tensor(np.zeros((imgs_lr.size(0), *discriminator.output_shape))), requires_grad=False)
                # ------------------
                #  Train Generators
                # ------------------

                optimizer_G.zero_grad()

                # Generate a high resolution image from low resolution input
                gen_hr = generator(imgs_lr)
                # print("gen_hr shape:", gen_hr.shape)
                # print("valid shape:", valid.shape)
                # Adversarial loss
                loss_GAN = criterion_GAN(discriminator(gen_hr), valid)

                # Content loss
                if opt.channels == 1:  # Duplicate single channel to 3 channel tensor to match vgg19's expected input
                    fe_gen_hr = gen_hr.expand(-1, 3, -1, -1)
                    fe_imgs_hr = imgs_hr.expand(-1, 3, -1, -1)
                else:
                    fe_gen_hr = gen_hr
                    fe_imgs_hr = imgs_hr

                gen_features = feature_extractor(fe_gen_hr)
                real_features = feature_extractor(fe_imgs_hr)
                # loss_content = criterion_content(gen_features, real_features.detach())
                loss_content = temp_err(gen_hr, imgs_hr)

                # Total loss
                loss_G = loss_content + 1e-3 * loss_GAN

                if torch.cuda.device_count() > 1 and opt.parallel == 1:
                    loss_G.sum().backward()
                else:
                    loss_G.backward()

                optimizer_G.step()

                # ---------------------
                #  Train Discriminator
                # ---------------------

                optimizer_D.zero_grad()

                # Loss of real and fake images
                loss_real = criterion_GAN(discriminator(imgs_hr), valid)
                loss_fake = criterion_GAN(discriminator(gen_hr.detach()), fake)

                # Total loss
                loss_D = (loss_real + loss_fake) / 2
                if torch.cuda.device_count() > 1 and opt.parallel == 1:
                    loss_D.sum().backward()
                else:
                    loss_D.backward()

                optimizer_D.step()

                # --------------
                #  Log Progress
                # --------------

                sys.stdout.write(
                    "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]\n"
                    % (epoch, opt.n_epochs, i, len(dataloader), loss_D.mean().item(), loss_G.mean().item())
                )

                batches_done = epoch * len(dataloader) + i
                if batches_done % opt.sample_interval == 0:
                    # Save image grid with upsampled inputs, outputs and ground truth
                    imgs_lr = nn.functional.interpolate(imgs_lr, scale_factor=opt.factor)
                    gen_hr = make_grid(gen_hr, nrow=1, normalize=True)
                    imgs_lr = make_grid(imgs_lr, nrow=1, normalize=True)
                    imgs_hr = make_grid(imgs_hr, nrow=1, normalize=True)
                    img_grid = torch.cat((imgs_lr, gen_hr, imgs_hr), -1)
                    save_image(img_grid, f"{imgs_out}/{epoch}_{i}.png", normalize=False)

            if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
                # Save model checkpoints
                if torch.cuda.device_count() > 1 and opt.parallel == 1:
                    torch.save(generator.module.state_dict(), f"{saved_models}/generator_{epoch}.pth")
                    torch.save(discriminator.module.state_dict(), f"{saved_models}/discriminator_{epoch}.pth")
                else:
                    torch.save(generator.state_dict(), f"{saved_models}/generator_{epoch}.pth")
                    torch.save(discriminator.state_dict(), f"{saved_models}/discriminator_{epoch}.pth")
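
For what it's worth, the checkpoints saved with and without DataParallel should at least contain the same parameter names, since the parallel branch saves generator.module.state_dict(). A quick way to verify that (file names below are placeholders for my actual runs):

    # Quick sanity check: compare key names of a "parallel" and a "single-GPU" checkpoint
    sd_parallel = torch.load("parallel_run/generator_199.pth", map_location="cpu")
    sd_single = torch.load("single_run/generator_199.pth", map_location="cpu")
    print(set(sd_parallel.keys()) == set(sd_single.keys()))  # expect True: .module.state_dict() strips the "module." prefix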

I hope someone can point me in the right direction. Thanks!