I trained a ProGAN model (using this repo) and now I want to use it to generate an image. The model was trained on a GPU cluster, and now I am using a single GPU to run it.
Following the instructions on the repo page, I load the .pth file into a generator wrapped in nn.DataParallel. These are the commands I run:
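import torch as th
from pro_gan_pytorch import PRO_GAN as pg2
# Use the GPU if one is available, otherwise fall back to the CPU
device = th.device("cuda" if th.cuda.is_available() else "cpu")
# Wrap the generator in DataParallel so its parameter names carry the "module." prefix
gen = th.nn.DataParallel(pg2.Generator(depth=6))
# gen = pg2.Generator()
# Load the trained generator weights onto the selected device
gen.load_state_dict(th.load("GAN_GEN_5.pth", map_location=str(device)))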
As a result, I get the error message below. What should I do to fix the issue?
RuntimeError: Error(s) in loading state_dict for DataParallel:
Unexpected key(s) in state_dict: "module.layers.4.conv_1.weight", "module.layers.4.conv_1.bias", "module.layers.4.conv_2.weight", "module.layers.4.conv_2.bias", "module.rgb_converters.5.weight", "module.rgb_converters.5.bias".
size mismatch for module.rgb_converters.0.weight: copying a param with shape torch.Size([1, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 512, 1, 1]).
size mismatch for module.rgb_converters.0.bias: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([3]).
size mismatch for module.rgb_converters.1.weight: copying a param with shape torch.Size([1, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 512, 1, 1]).
size mismatch for module.rgb_converters.1.bias: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([3]).
size mismatch for module.rgb_converters.2.weight: copying a param with shape torch.Size([1, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 512, 1, 1]).
size mismatch for module.rgb_converters.2.bias: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([3]).
size mismatch for module.rgb_converters.3.weight: copying a param with shape torch.Size([1, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 512, 1, 1]).
size mismatch for module.rgb_converters.3.bias: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([3]).
size mismatch for module.rgb_converters.4.weight: copying a param with shape torch.Size([1, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([3, 256, 1, 1]).
size mismatch for module.rgb_converters.4.bias: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([3]).
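One way to narrow this down is to check what the checkpoint actually contains before building the generator. The message above already points at two differences: the checkpoint has more blocks than the model built in the script (the unexpected layers.4 and rgb_converters.5 keys), and its RGB converters output 1 channel where the model expects 3. The sketch below is a hypothetical helper, summarize_checkpoint, which is not part of pro_gan_pytorch. It assumes the key layout shown in the error message (an optional "module." prefix, then "layers.<i>" and "rgb_converters.<i>") and only counts blocks and output channels from a mapping of parameter names to shapes.

```python
import re
from typing import Dict, Mapping, Sequence


def summarize_checkpoint(shapes: Mapping[str, Sequence[int]]) -> Dict[str, object]:
    """Summarize a generator state_dict given as {parameter name: shape}.

    Hypothetical helper: assumes keys look like "module.layers.<i>...."
    and "module.rgb_converters.<i>.weight", as in the error message.
    """
    # True only if there is at least one key and all keys carry the prefix
    has_prefix = bool(shapes) and all(k.startswith("module.") for k in shapes)

    layer_ids, rgb_ids, out_channels = set(), set(), set()
    for key, shape in shapes.items():
        # Drop the DataParallel prefix so both key styles are handled
        name = key[len("module."):] if key.startswith("module.") else key
        match = re.match(r"(layers|rgb_converters)\.(\d+)\.(.+)", name)
        if match is None:
            continue
        kind, index, rest = match.group(1), int(match.group(2)), match.group(3)
        if kind == "layers":
            layer_ids.add(index)
        else:
            rgb_ids.add(index)
            if rest == "weight":
                # First dimension of a conv weight is its output channel count
                out_channels.add(int(shape[0]))

    return {
        "module_prefix": has_prefix,
        "num_layers": len(layer_ids),
        "num_rgb_converters": len(rgb_ids),
        "output_channels": sorted(out_channels),
    }
```

Assuming GAN_GEN_5.pth holds a plain state_dict (as the load_state_dict call implies), the mapping can be built with {k: tuple(v.shape) for k, v in th.load("GAN_GEN_5.pth", map_location="cpu").items()}. The reported counts can then be compared against the depth and the number of image channels used at training time.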