This may be a basic question. I am implementing a GNN model, and for each relation type r I want to create a separate weight matrix W_r using `nn.Linear(dim, dim)`. Here is part of my code:
class GNN(nn.Module):
    """Graph neural network with one learnable linear transform per relation.

    Args:
        entity_num: Number of entities (rows of the entity embedding table).
        relation_num: Number of relation types. One ``nn.Linear(dim, dim)``
            (the relation-specific matrix W_r) is created for each.
        dim: Embedding / hidden dimension.
        args: Extra configuration object; not read by this constructor.
        pad_index: Optional entity index whose embedding row is treated as
            padding (forwarded to ``nn.Embedding(padding_idx=...)``).
            Defaults to ``None``, meaning no padding row.

    Usage of the per-relation weights: ``self.W_list[r](x)`` applies W_r to x.
    """

    def __init__(self, entity_num, relation_num, dim, args, pad_index=None):
        super().__init__()
        self.entity_num = entity_num
        self.relation_num = relation_num
        self.dim = dim
        # Must be assigned before the entity embedding below reads it;
        # reading an unset attribute on an nn.Module raises AttributeError.
        self.pad_index = pad_index

        # Entity and relation embedding lookup tables.
        self.entity_embedding = nn.Embedding(
            self.entity_num, self.dim, padding_idx=self.pad_index
        )
        self.relation_embedding = nn.Embedding(self.relation_num, self.dim)

        # One W_r per relation type. nn.ModuleList (not a plain Python list)
        # registers every Linear as a submodule, so its weights show up in
        # .parameters() (and thus the optimizer), are saved in state_dict(),
        # and move with .to(device) / .cuda().
        self.W_list = nn.ModuleList(
            nn.Linear(self.dim, self.dim) for _ in range(self.relation_num)
        )
How should I implement `self.W_list` so that it holds one `nn.Linear(dim, dim)` per relation type?