How do I create multiple nn.Linear() layers in a loop?

This may be a basic question. I am implementing a GNN model, and for each relation type r I want to create a separate weight matrix W_r using nn.Linear(dim, dim). Here is part of my code:

```python
import torch
import torch.nn as nn

class GNN(nn.Module):
    def __init__(self, entity_num, relation_num, dim, args):
        super(GNN, self).__init__()
        self.entity_num = entity_num
        self.relation_num = relation_num
        self.dim = dim
        # create entity and relation embedding lookup tables
        # (note: self.pad_index must be defined before this line)
        self.entity_embedding = torch.nn.Embedding(self.entity_num, self.dim, padding_idx=self.pad_index)
        self.relation_embedding = torch.nn.Embedding(self.relation_num, self.dim)
        # The next line creates a single weight matrix W, but I want one W per relation type
        # self.W = nn.Linear(self.dim, self.dim)
        self.W_list = ?  # [nn.Linear(self.dim, self.dim) for i in range(relation_num)]
```

How do I implement `self.W_list = ?`

You might want to use an `nn.ModuleList`:

```python
# inside __init__: create one nn.Linear(dim, dim) per relation type
self.W_list = nn.ModuleList()
for i in range(relation_num):
    self.W_list.append(nn.Linear(self.dim, self.dim))
```
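The reason to prefer `nn.ModuleList` over a plain Python list is that it registers each layer as a submodule, so the weights appear in `.parameters()` (and therefore get updated by the optimizer), move with `.to(device)`, and are saved in `state_dict()`. A minimal sketch illustrating this, assuming a stripped-down module rather than the full GNN above; the class names `RelationalLinear` and `PlainListLinear` and the `relation_id` argument are hypothetical and only for illustration. It also uses the fact that `nn.ModuleList` accepts an iterable in its constructor, which is equivalent to the `append` loop:

```python
import torch
import torch.nn as nn


class RelationalLinear(nn.Module):
    """Hypothetical minimal module: one nn.Linear(dim, dim) per relation type."""

    def __init__(self, relation_num, dim):
        super().__init__()
        # nn.ModuleList registers each Linear as a submodule, so its weights
        # appear in .parameters(), move with .to(device), and go into state_dict()
        self.W_list = nn.ModuleList(
            [nn.Linear(dim, dim) for _ in range(relation_num)]
        )

    def forward(self, h, relation_id):
        # Apply the weight matrix W_r that belongs to relation `relation_id`
        return self.W_list[relation_id](h)


class PlainListLinear(nn.Module):
    """For comparison: a plain Python list does NOT register the layers."""

    def __init__(self, relation_num, dim):
        super().__init__()
        self.W_list = [nn.Linear(dim, dim) for _ in range(relation_num)]


if __name__ == "__main__":
    relation_num, dim = 3, 4
    good = RelationalLinear(relation_num, dim)
    bad = PlainListLinear(relation_num, dim)

    # Each Linear(4, 4) has 4*4 weights + 4 biases = 20 parameters
    print(sum(p.numel() for p in good.parameters()))  # 60
    print(sum(p.numel() for p in bad.parameters()))   # 0

    h = torch.randn(2, dim)  # a batch of 2 node embeddings
    out = good(h, relation_id=1)
    print(out.shape)  # torch.Size([2, 4])
```

With the plain list, an optimizer built from `model.parameters()` would never see the per-relation weights, so they would silently stay at their initial values.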

This works! Thank you, @a_d.