Error saving all parameters

Hi, I am currently trying to implement a variant of autoencoder as follows:

    class AE_net(nn.Module):
        def __init__(self, node_num_1, node_num_2, activations, hierarchical=1, hi_variant=2):
            super(AE_net, self).__init__()
            self._activations = activations
            self._hierarchical=hierarchical
            self._hi_variant = hi_variant
            encoder_list = [nn.Sequential(*[nn.Linear(node_num_1[item], node_num_1[item + 1]), nn.Tanh()])
                            for item in range(len(node_num_1) - 2)]
            self._encoder_1 = nn.Sequential(*encoder_list)
            if not hierarchical:
                self._encoder_2 = [nn.Sequential(nn.Linear(node_num_1[-2], node_num_1[-1]), nn.Tanh())]
                decoder_list = [[nn.Linear(node_num_2[item], node_num_2[item + 1]), nn.Tanh()]
                                for item in range(len(node_num_2) - 1)]
                self._decoder = nn.Sequential(*[nn.Sequential(*item) for item in decoder_list])
            else:
                self._encoder_2 = [nn.Sequential(nn.Linear(node_num_1[-2], 1), nn.Tanh())
                                   for _ in range(node_num_1[-1])]
                if hi_variant == 2:
                    temp_node_num_2 = node_num_2[:]
                    temp_node_num_2[0] = 1
                    self._decoder = [nn.Sequential(*[nn.Sequential(
                        nn.Linear(temp_node_num_2[item], temp_node_num_2[item + 1]), nn.Tanh())
                        for item in range(len(temp_node_num_2) - 1)])
                                       for _ in range(node_num_2[0])]
            return

        def forward(self, x):
            temp = self._encoder_1(x)
            latent_z_split = [item_l(temp) for item_l in self._encoder_2]
            latent_z = torch.cat(latent_z_split, dim=-1)
            if not self._hierarchical:
                rec_x = self._decoder(latent_z)
            elif self._hi_variant == 2:
                temp_decoded = [self._decoder[item](latent_z_split[item]) for item in range(len(self._decoder))]
                decoded_list = [temp_decoded[0]]
                for item in temp_decoded[1:]:
                    decoded_list.append(torch.add(decoded_list[-1], item))
                rec_x = torch.cat(decoded_list, dim=-1)
            return rec_x, latent_z

But when I save the model using torch.save(the_model.state_dict(), PATH), only parameters of _encoder_1 are saved. All parameters in other parts (lists of nn.Module objects) are not saved. Does anyone know what might be the reason and how I could save all these parameters?

Thanks!

use ModuleList instead of python list https://pytorch.org/docs/master/nn.html#modulelist. pytorch can’t know about things you hide in a plain list.

1 Like