Model doesn't register parameters if done through dict

Hi,

I can’t figure out how to make the model register my params if they are initialized as values in a dict.

Example code:

import torch

class EmbeddingModel(torch.nn.Module):

    def __init__(self):
        super(EmbeddingModel, self).__init__()

        self.embeddings = {
            'source': None,
            'target': None
        }

        self.init_embedding("source", 10, 5)
        self.init_embedding("target", 10, 4)

    def init_embedding(self, emb_type, num_items, emb_dim):
        self.embeddings[emb_type] = torch.nn.Embedding(
            num_embeddings=num_items, embedding_dim=emb_dim)

    def forward(self, src_idx):
        pass


model = EmbeddingModel()

for param in model.parameters():
    print(param)  # <- Prints nothing: the model has no registered parameters.

Is there a way to do this cleanly?

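That’s expected. nn.Module only registers submodules that are assigned directly as attributes or stored in a registering container, so modules kept in a plain Python dict are invisible to model.parameters(). You would need to use nn.ModuleDict instead of a plain Python dict.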
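For reference, here is a minimal sketch of the suggestion above applied to the model from the question. The only real change is swapping the plain dict for nn.ModuleDict. The forward body is a placeholder I filled in just to show the stored embeddings are usable; it is an assumption, not part of the original code.

```python
import torch


class EmbeddingModel(torch.nn.Module):

    def __init__(self):
        super().__init__()

        # nn.ModuleDict registers every module stored in it as a submodule,
        # so its parameters show up in model.parameters().
        self.embeddings = torch.nn.ModuleDict()

        self.init_embedding("source", 10, 5)
        self.init_embedding("target", 10, 4)

    def init_embedding(self, emb_type, num_items, emb_dim):
        self.embeddings[emb_type] = torch.nn.Embedding(
            num_embeddings=num_items, embedding_dim=emb_dim)

    def forward(self, src_idx):
        # Placeholder (assumption): look up the "source" embedding only.
        return self.embeddings["source"](src_idx)


model = EmbeddingModel()

for name, param in model.named_parameters():
    print(name, tuple(param.shape))
# embeddings.source.weight (10, 5)
# embeddings.target.weight (10, 4)
```

Because the embeddings are now registered, they should also be picked up by model.to(device), model.state_dict(), and any optimizer built from model.parameters().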

I knew it’s gonna be you! :joy:

Didn’t know about ModuleDict. That looks like exactly what I need. Thanks Patrick!
