Hi,
I can’t figure out how to make the model register my params if they are initialized as values in a dict.
Example code,
import torch
class EmbeddingModel(torch.nn.Module):
def __init__(self):
super(EmbeddingModel, self).__init__()
self.embeddings = {
'source': None,
'target': None
}
self.init_embedding("source", 10, 5)
self.init_embedding("target", 10, 4)
def init_embedding(self, emb_type, num_items, emb_dim):
self.embeddings[emb_type] = torch.nn.Embedding(
num_embeddings=num_items, embedding_dim=emb_dim)
def forward(self, src_idx):
pass
model = EmbeddingModel()
for param in model.parameters():
print(param) # <- No parameters in the model.
Is there a way to do this cleanly?