import numpy as np
import torch
from data_load import NpyDataset
from main import Generator
# Inference script: load trained Generator weights, generate a fixed batch of
# samples from random noise, undo the dataset's min/max scaling, and save the
# result as a .npy file.

# Path of the *generator* checkpoint. 'netd_0.pth' cannot be used here: its
# main_module.0.weight has shape [256, 15, 4, 5] (a conv taking the 15-channel
# data as input, i.e. discriminator-shaped), whereas Generator(channels=15)
# expects [100, 1024, 4, 3] for that layer.
# NOTE(review): 'netg_0.pth' assumes the usual netd/netg naming used by the
# training script — confirm the actual generator checkpoint filename.
GENERATOR_CHECKPOINT = 'netg_0.pth'

device = torch.device('cpu')

generator = Generator(channels=15)
# map_location lets a checkpoint that was saved from a GPU load on this
# CPU-only device.
checkpoint = torch.load(GENERATOR_CHECKPOINT, map_location=device)
# Keep strict loading (the default). strict=False only tolerates missing or
# unexpected keys and never tolerates shape mismatches, so it could only hide
# a wrong checkpoint by leaving some layers randomly initialised.
generator.load_state_dict(checkpoint)
generator.to(device)
generator.eval()

# NOTE(review): the noise shape (64, 100, 4, 4) is kept from the original;
# confirm it matches the latent shape used during training.
fix_noises = torch.randn(64, 100, 4, 4, device=device)
dataset = NpyDataset('10000_npy/')

# Inference only: no autograd graph is needed, so no .detach() is required
# later.
with torch.no_grad():
    fix_fake_image = generator(fix_noises)
    # Map generator output back to the data range using the dataset's min/max.
    # NOTE(review): assumes the generator outputs values in [0, 1] (min-max
    # normalised data); if its last activation is tanh ([-1, 1]) this mapping
    # is off — confirm.
    fix_fake_image = fix_fake_image * (dataset.max - dataset.min) + dataset.min

fix_fake_image = fix_fake_image[:64].cpu().numpy()
# np.save appends '.npy', so this writes PERM/perm_1.npy; the PERM directory
# must already exist.
np.save('PERM/perm_1', fix_fake_image)
# Error raised by load_state_dict when 'netd_0.pth' was loaded into Generator(channels=15):
# RuntimeError: Error(s) in loading state_dict for Generator:
# size mismatch for main_module.0.weight: copying a param with shape torch.Size([256, 15, 4, 5]) from checkpoint, the shape in current model is torch.Size([100, 1024, 4, 3]).
# size mismatch for main_module.0.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1024]).
# size mismatch for main_module.1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1024]).
# size mismatch for main_module.1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1024]).
# size mismatch for main_module.1.running_mean: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1024]).
# size mismatch for main_module.1.running_var: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1024]).
# size mismatch for main_module.3.weight: copying a param with shape torch.Size([512, 256, 4, 5]) from checkpoint, the shape in current model is torch.Size([1024, 512, 4, 3]).
# size mismatch for main_module.6.weight: copying a param with shape torch.Size([1024, 512, 4, 5]) from checkpoint, the shape in current model is torch.Size([512, 256, 4, 2]).
# size mismatch for main_module.6.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([256]).
# size mismatch for main_module.7.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([256]).
# size mismatch for main_module.7.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([256]).
# size mismatch for main_module.7.running_mean: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([256]).
# size mismatch for main_module.7.running_var: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([256]).