Hi, I have reimplemented the GAN for grayscale radiology data using the following Github Repo:
GitHub - mdraw/BMSG-GAN at img_channels
I want to generate fake samples using generate_samples.py code but facing following error:
Traceback (most recent call last):
File "generate_samples.py", line 144, in <module>
main(parse_arguments())
File "generate_samples.py", line 113, in main
th.load(args.generator_file)
File "/home/r00206978/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1498, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for DataParallel:
size mismatch for module.rgb_converters.0.weight: copying a param with shape torch.Size([1, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 512, 1, 1]).
size mismatch for module.rgb_converters.0.bias: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([3]).
size mismatch for module.rgb_converters.1.weight: copying a param with shape torch.Size([1, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 512, 1, 1]).
size mismatch for module.rgb_converters.1.bias: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([3]).
size mismatch for module.rgb_converters.2.weight: copying a param with shape torch.Size([1, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 512, 1, 1]).
size mismatch for module.rgb_converters.2.bias: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([3]).
size mismatch for module.rgb_converters.3.weight: copying a param with shape torch.Size([1, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 512, 1, 1]).
size mismatch for module.rgb_converters.3.bias: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([3]).
size mismatch for module.rgb_converters.4.weight: copying a param with shape torch.Size([1, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 256, 1, 1]).
size mismatch for module.rgb_converters.4.bias: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([3]).
size mismatch for module.rgb_converters.5.weight: copying a param with shape torch.Size([1, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 128, 1, 1]).
size mismatch for module.rgb_converters.5.bias: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([3]).
I trained a GAN with depth=6 and latent_size=512. I am running following code for generating new samples but failed:
python generate_samples.py --generator_file results/GAN_GEN_450.pth --latent_size 512 --depth 6 --out_depth 4 --num_samples 200 --out_dir results/imgs/
The generate_samples.py code:
""" Generate single image samples from a particular depth of a model """
import argparse
import torch as th
import numpy as np
import os
from torch.backends import cudnn
from MSG_GAN.GAN import Generator
from torch.nn.functional import interpolate
#from scipy.misc import imsave
import imageio
from tqdm import tqdm
# turn on the fast GPU processing mode on
cudnn.benchmark = True
# set the manual seed
# th.manual_seed(3)
def parse_arguments():
"""
default command line argument parser
:return: args => parsed command line arguments
"""
parser = argparse.ArgumentParser()
parser.add_argument("--generator_file", action="store", type=str,
help="pretrained weights file for generator", required=True)
parser.add_argument("--latent_size", action="store", type=int,
default=256,
help="latent size for the generator")
parser.add_argument("--depth", action="store", type=int,
default=9,
help="depth of the network. **Starts from 1")
parser.add_argument("--out_depth", action="store", type=int,
default=6,
help="output depth of images. **Starts from 0")
parser.add_argument("--num_samples", action="store", type=int,
default=300,
help="number of synchronized grids to be generated")
parser.add_argument("--out_dir", action="store", type=str,
default="interp_animation_frames/",
help="path to the output directory for the frames")
args = parser.parse_args()
return args
def adjust_dynamic_range(data, drange_in=(-1, 1), drange_out=(0, 1)):
"""
adjust the dynamic colour range of the given input data
:param data: input image data
:param drange_in: original range of input
:param drange_out: required range of output
:return: img => colour range adjusted images
"""
if drange_in != drange_out:
scale = (np.float32(drange_out[1]) - np.float32(drange_out[0])) / (
np.float32(drange_in[1]) - np.float32(drange_in[0]))
bias = (np.float32(drange_out[0]) - np.float32(drange_in[0]) * scale)
data = data * scale + bias
return th.clamp(data, min=0, max=1)
def progressive_upscaling(images):
"""
upsamples all images to the highest size ones
:param images: list of images with progressively growing resolutions
:return: images => images upscaled to same size
"""
with th.no_grad():
for factor in range(1, len(images)):
images[len(images) - 1 - factor] = interpolate(
images[len(images) - 1 - factor],
scale_factor=pow(2, factor)
)
return images
def main(args):
"""
Main function for the script
:param args: parsed command line arguments
:return: None
"""
print("Creating generator object ...")
# create the generator object
gen = th.nn.DataParallel(Generator(
depth=args.depth,
latent_size=args.latent_size
))
print("Loading the generator weights from:", args.generator_file)
# load the weights into it
gen.load_state_dict(
th.load(args.generator_file)
)
# path for saving the files:
save_path = args.out_dir
print("Generating scale synchronized images ...")
for img_num in tqdm(range(1, args.num_samples + 1)):
# generate the images:
with th.no_grad():
point = th.randn(1, args.latent_size)
point = (point / point.norm()) * (args.latent_size ** 0.5)
ss_images = gen(point)
# resize the images:
ss_images = [adjust_dynamic_range(ss_image) for ss_image in ss_images]
ss_images = progressive_upscaling(ss_images)
ss_image = ss_images[args.out_depth]
# save the ss_image in the directory
imageio.imwrite(os.path.join(save_path, str(img_num) + ".png"),
ss_image.squeeze(0).permute(1, 2, 0).cpu())
print("Generated %d images at %s" % (args.num_samples, save_path))
if __name__ == '__main__':
main(parse_arguments())
However, this code works fine when generate RGB images from trained GAN on RGB data. But not working for generating radiology images when GAN trained on radiology data. Could you please help me where could I need to change this code??