class MMGCN(torch.nn.Module):
    """Single-modality (text-feature) MMGCN-style recommender.

    Users and items share one ID-embedding table with ``num_user + num_item``
    rows. A ``GCN`` over the user-item graph maps that table to node
    representations, and ``forward`` scores (user, item) pairs by dot product.
    """

    def __init__(self, features, edge_index, batch_size, num_user, num_item, aggr_mode, concate, num_layer, has_id,
                 dim_x, dim_latent=25, device="cuda"):
        """Build the graph, the feature tensor, the GCN and the embedding tables.

        Args:
            features: Per-node text features. Copied into a float tensor on ``device``.
            edge_index: Edge list, presumably shape (num_edges, 2) of node ids
                (TODO confirm). It is transposed to (2, num_edges), and every edge
                is appended again with its endpoints swapped, so the GCN sees each
                edge in both directions.
            batch_size, aggr_mode, concate, num_layer, has_id: Stored and/or
                forwarded to ``GCN`` unchanged.
            num_user, num_item: Node counts. The shared embedding tables have
                ``num_user + num_item`` rows.
            dim_x: Width of the ID embeddings (also forwarded to ``GCN``).
            dim_latent: Latent size forwarded to ``GCN``. Defaults to the
                previously hard-coded 25.
            device: Where all tensors are created. Defaults to ``"cuda"``, which
                matches the previous hard-coded ``.cuda()`` calls.
        """
        super().__init__()
        self.batch_size = batch_size
        self.num_user = num_user
        self.num_item = num_item
        self.aggr_mode = aggr_mode
        self.concate = concate

        # Node indices must be integers; transpose (E, 2) -> (2, E), then add reversed edges.
        edges = torch.tensor(edge_index, dtype=torch.long, device=device).t().contiguous()
        self.edge_index = torch.cat((edges, edges[[1, 0]]), dim=1)

        self.t_feat = torch.tensor(features, dtype=torch.float, device=device)
        self.t_gcn = GCN(self.t_feat, self.edge_index, batch_size, num_user, num_item, dim_x, self.aggr_mode,
                         self.concate, num_layer=num_layer, has_id=has_id, dim_latent=dim_latent)

        # Registered as a Parameter so it is returned by .parameters() and trained,
        # and it moves with .to()/.cuda(). The old `xavier_normal_(...).cuda()` produced
        # a non-leaf copy whose gradients went to an unreachable CPU tensor, so no
        # optimizer built from model.parameters() ever updated it.
        # NOTE(review): this adds an "id_embedding" entry to state_dict(), so strict
        # loading of checkpoints saved before this change will report it missing.
        self.id_embedding = nn.Parameter(
            nn.init.xavier_normal_(torch.empty((num_user + num_item, dim_x), device=device)))

        # Placeholder until the first forward() overwrites it with the GCN output.
        self.result_embed = nn.init.xavier_normal_(torch.empty((num_user + num_item, dim_x), device=device))

    def forward(self, user_nodes, pos_item_nodes, neg_item_nodes):
        """Score each user against a positive and a negative item.

        Args:
            user_nodes, pos_item_nodes, neg_item_nodes: Index tensors into the
                shared node table, presumably aligned 1-D tensors of equal length
                (item ids presumably already offset by ``num_user`` — confirm
                against the data loader).

        Returns:
            ``(pos_scores, neg_scores)``: row-wise dot products of the user
            representations with the positive and negative item representations.
        """
        # Only one modality here, so no averaging is needed (the former `/ 1` was a no-op copy).
        representation = self.t_gcn(self.id_embedding)
        # Cached on the module, presumably for later evaluation. It is not detached,
        # so it still references this forward's autograd graph.
        self.result_embed = representation
        user_tensor = representation[user_nodes]
        pos_item_tensor = representation[pos_item_nodes]
        neg_item_tensor = representation[neg_item_nodes]
        pos_scores = torch.sum(user_tensor * pos_item_tensor, dim=1)
        neg_scores = torch.sum(user_tensor * neg_item_tensor, dim=1)
        return pos_scores, neg_scores