Empty model.state_dict() after training generator?

Hello!

I have trained a generator model as defined by:

class Downscale(nn.Module):
    """Encoder stage built around a 4x4, stride-2, padding-1 convolution.

    The layers are stored in ``self.model`` in this order:
    ``Conv2d`` (no bias) -> optional ``BatchNorm2d`` -> ``LeakyReLU(0.2)``
    -> optional ``Dropout``.  With these convolution settings an even
    spatial size is halved (e.g. 8x8 -> 4x4).

    Args:
        in_size: Number of input channels.
        out_size: Number of output channels.
        normalize: When true, a ``BatchNorm2d`` follows the convolution.
        dropout: Dropout probability; a falsy value (the default ``0.0``)
            means no ``Dropout`` layer is added.
    """

    def __init__(self, in_size, out_size, normalize = True, dropout = 0.0):
        super().__init__()

        stages = [
            nn.Conv2d(in_size, out_size, kernel_size=4, stride=2, padding=1, bias=False),
        ]

        if normalize:
            # NOTE(review): BatchNorm2d's second positional parameter is
            # ``eps``, so the original ``BatchNorm2d(out_size, 0.8)`` sets
            # eps=0.8.  This was possibly meant to be momentum; it is kept
            # as-is because changing it would alter the numerics of
            # already-trained checkpoints -- confirm before changing.
            stages.append(nn.BatchNorm2d(out_size, eps=0.8))

        stages.append(nn.LeakyReLU(0.2))

        if dropout:
            stages.append(nn.Dropout(dropout))

        # The attribute name ``model`` is part of every state_dict key
        # ("model.0.weight", ...), so it must not be renamed.
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stage and return the resulting feature map."""
        return self.model(x)
##############################################################################       
# #############################################################################   
class Upscale(nn.Module):
    """Decoder stage: transposed convolution followed by a skip concatenation.

    ``self.model`` holds ``ConvTranspose2d`` (4x4, stride 2, padding 1, no
    bias) -> ``BatchNorm2d`` -> ``ReLU`` -> optional ``Dropout``.  With these
    settings the spatial size is doubled (e.g. 2x2 -> 4x4).

    Args:
        in_size: Number of input channels.
        out_size: Number of channels produced by the transposed convolution
            (the tensor returned by ``forward`` additionally carries the
            channels of the skip input).
        dropout: Dropout probability; a falsy value (the default ``0.0``)
            means no ``Dropout`` layer is added.
    """

    def __init__(self, in_size, out_size, dropout = 0.0):
        super().__init__()

        deconv = nn.ConvTranspose2d(
            in_size, out_size, kernel_size=4, stride=2, padding=1, bias=False
        )
        # NOTE(review): the original passes 0.8 as BatchNorm2d's second
        # positional argument, which is ``eps``.  This was possibly meant to
        # be momentum; it is kept unchanged so existing checkpoints behave
        # the same -- confirm before changing.
        norm = nn.BatchNorm2d(out_size, eps=0.8)
        activation = nn.ReLU(inplace=True)

        optional_dropout = [nn.Dropout(dropout)] if dropout else []

        # The attribute name ``model`` appears in state_dict keys; keep it.
        self.model = nn.Sequential(deconv, norm, activation, *optional_dropout)

    def forward(self, x, skip_input):
        """Upsample ``x`` and concatenate ``skip_input`` along dim 1.

        ``skip_input`` must match the upsampled tensor in every dimension
        except dim 1, as required by ``torch.cat``.
        """
        upsampled = self.model(x)
        return torch.cat((upsampled, skip_input), dim=1)
##############################################################################       
# #############################################################################   

class Generator(nn.Module):
    """U-Net style generator with six encoder and five decoder stages.

    Construction is two-phase: ``__init__`` only records the hyperparameters,
    and ``build()`` creates and registers the layers.  Until ``build()`` has
    been called the module has no submodules, so ``parameters()`` and
    ``state_dict()`` are empty and ``forward`` cannot run.

    Args:
        features_g: Base channel width; stage widths are multiples of it.
        num_channels: Channel count of the input image, of the constraint
            map and of the generated output.
    """

    def __init__(self, features_g, num_channels):
        super().__init__()
        self.num_channels = num_channels
        self.features_g = features_g

    def build(self):
        """Create and register every layer of the network.

        Channel counts below follow from the constructor arguments; the
        spatial sizes in the comments assume a 64x64 input, as in the
        original annotations.
        """
        width = self.features_g
        channels = self.num_channels

        # ---- encoder -------------------------------------------------
        # channels x 64 x 64 -> width x 32 x 32 (no batch norm on stage 1)
        self.down1 = Downscale(in_size=channels, out_size=width, normalize=False)
        # width x 32 x 32 -> (2*width) x 16 x 16
        self.down2 = Downscale(in_size=width, out_size=width * 2)
        # forward() concatenates the constraint map onto the down2 output,
        # hence the extra ``channels`` input planes here.
        # (2*width + channels) x 16 x 16 -> (4*width) x 8 x 8
        self.down3 = Downscale(
            in_size=width * 2 + channels, out_size=width * 4, dropout=0.5
        )
        # (4*width) x 8 x 8 -> (8*width) x 4 x 4
        self.down4 = Downscale(in_size=width * 4, out_size=width * 8, dropout=0.5)
        # (8*width) x 4 x 4 -> (8*width) x 2 x 2
        self.down5 = Downscale(in_size=width * 8, out_size=width * 8, dropout=0.5)
        # (8*width) x 2 x 2 -> (8*width) x 1 x 1 (bottleneck)
        self.down6 = Downscale(in_size=width * 8, out_size=width * 8, dropout=0.5)

        # ---- decoder -------------------------------------------------
        # Each Upscale returns its own output concatenated with a skip
        # tensor, so the next stage's in_size is the sum of both widths.
        # in: (8*width) x 1 x 1; out after skip d5: (16*width) x 2 x 2
        self.up1 = Upscale(in_size=width * 8, out_size=width * 8, dropout=0.5)
        # out after skip d4: (16*width) x 4 x 4
        self.up2 = Upscale(in_size=width * 16, out_size=width * 8, dropout=0.5)
        # out after skip d3: (8*width) x 8 x 8
        self.up3 = Upscale(in_size=width * 16, out_size=width * 4, dropout=0.5)
        # out after skip d2: (4*width + channels) x 16 x 16
        self.up4 = Upscale(in_size=width * 8, out_size=width * 2)
        # out after skip d1: (2*width) x 32 x 32
        self.up5 = Upscale(in_size=width * 4 + channels, out_size=width)

        # ---- output head ---------------------------------------------
        # (2*width) x 32 x 32 -> upsample -> 3x3 conv -> channels x 64 x 64,
        # squashed into [-1, 1] by tanh.
        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(
                in_channels=width * 2,
                out_channels=channels,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.Tanh(),
        )

    def forward(self, input, constraint_map):
        """Generate an image from ``input`` conditioned on ``constraint_map``.

        ``constraint_map`` is concatenated (dim 1) onto the output of the
        second encoder stage, so it must have ``num_channels`` channels and
        the same batch and spatial size as that feature map (16x16 for a
        64x64 input).  Requires ``build()`` to have been called.

        The parameter name ``input`` shadows the builtin; it is kept because
        callers may pass it by keyword.
        """
        enc1 = self.down1(input)
        enc2 = torch.cat((self.down2(enc1), constraint_map), dim=1)
        enc3 = self.down3(enc2)
        enc4 = self.down4(enc3)
        enc5 = self.down5(enc4)
        bottleneck = self.down6(enc5)

        # Walk back up, pairing each decoder stage with the encoder output
        # of matching resolution.
        decoder_stages = (self.up1, self.up2, self.up3, self.up4, self.up5)
        skip_tensors = (enc5, enc4, enc3, enc2, enc1)
        features = bottleneck
        for stage, skip in zip(decoder_stages, skip_tensors):
            features = stage(features, skip)

        return self.final(features)

    def define_optim(self, learning_rate, beta1):
        """Attach an Adam optimizer over this module's parameters.

        The optimizer is stored as ``self.optimizer`` with
        ``betas=(beta1, 0.999)``.  It captures the parameters that exist at
        call time, so call this after ``build()``.
        """
        adam_betas = (beta1, 0.999)
        self.optimizer = optim.Adam(
            self.parameters(), lr=learning_rate, betas=adam_betas
        )

    @staticmethod
    def init_weights(layers):
        """Initialise one layer in place, chosen by its class name.

        Classes whose name contains ``"Conv"`` get weights drawn from
        N(0.0, 0.02); classes whose name contains ``"BatchNorm"`` get weights
        from N(1.0, 0.02) and a zero bias.  Any other layer is left
        untouched.  The one-module signature fits ``nn.Module.apply``.
        """
        layer_kind = layers.__class__.__name__

        if 'Conv' in layer_kind:
            nn.init.normal_(layers.weight.data, 0.0, 0.02)
        elif 'BatchNorm' in layer_kind:
            nn.init.normal_(layers.weight.data, 1.0, 0.02)
            nn.init.constant_(layers.bias.data, 0)

Once trained, I save the state with

# Serialise the generator's state dict to disk; this stores what
# gen.state_dict() returns, not the Generator object itself.
torch.save(gen.state_dict(), "GENERATOR/gen.pt")

However, upon loading

# Construct a fresh generator with the given hyperparameters; no other
# Generator method is invoked in this snippet.
model = Generator(features_g = 64, num_channels = 3)

# Iterate over the object returned by torch.load and print each key
# (the parameter/buffer names stored in the checkpoint file).
for param in torch.load('GENERATOR/gen.pt'):
    print(param)
    
# Print the new model's own state dict, then each of its keys, for
# comparison with the checkpoint keys printed above.
print(model.state_dict())
for param in model.state_dict():
    print(param)

torch.load gives the expected keys, but model.state_dict() returns an empty OrderedDict().

Any advice?

Thanks!

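Based on your code snippet, the Generator.__init__ method doesn’t initialize any modules; it just stores the number of features and the number of channels.
I therefore guess that you’ve forgotten either to call self.build() inside __init__ or to call model.build() after creating the model, so the modules are never actually created. Once they are, model.load_state_dict(torch.load('GENERATOR/gen.pt')) will restore the trained weights.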

1 Like

Thanks! I’ll give it a go and see what happens :smiley: