The full model architecture follows: a GraphSAGE-style Encoder, a MeanAggregation neighbour pooler, and the GraphSage classifier head.
class Encoder(nn.Module):
    """GraphSAGE-style encoder.

    For each node, concatenates its own feature vector with an aggregated
    representation of (sampled) neighbours, then applies a learned linear
    map followed by ReLU.
    """

    def __init__(self, features, adj_matrix, embed_dim, feature_dim, agg, cuda):
        # feature_dim: dimensionality of one feature vector (in our case the PPI feature dim)
        super(Encoder, self).__init__()
        self.features = features        # embedding lookup: LongTensor of ids -> (n, feature_dim)
        self.adj_matrix = adj_matrix    # node id -> iterable of neighbour ids
        self.embed_dim = embed_dim
        # Maps concatenated [self || aggregated-neighbour] features to embed_dim.
        # (was torch.torch.FloatTensor — torch.torch is an accidental self-reference)
        self.weight = nn.Parameter(torch.FloatTensor(embed_dim, 2 * feature_dim))
        init.xavier_uniform_(self.weight)  # in-place variant; bare xavier_uniform is deprecated
        self.aggregation = agg
        # NOTE(review): assigning a bool here shadows nn.Module.cuda();
        # attribute name kept because MeanAggregation and callers read it.
        self.cuda = cuda
        self.aggregation.cuda = cuda

    def forward(self, nodes):
        """Return a (embed_dim, len(nodes)) tensor of node embeddings."""
        # Aggregate up to 40 sampled neighbours per node.
        agg = self.aggregation(
            nodes,
            [self.adj_matrix[int(node)] for node in nodes],
            num_sample=40,
        )
        ids = torch.LongTensor(nodes)
        if self.cuda:
            # Bug fix: previously both branches kept ids on CPU, so the GPU
            # path looked up embeddings with a CPU index tensor.
            ids = ids.cuda()
        node_embed = self.features(ids)
        fe1 = torch.cat([node_embed, agg], dim=1)
        if self.weight.is_cuda:
            fe1 = fe1.cuda()
        # (embed_dim, 2*feature_dim) @ (2*feature_dim, n) -> (embed_dim, n)
        return F.relu(self.weight.mm(fe1.t()))
class MeanAggregation(nn.Module):
    """Mean-pools the feature vectors of each node's (sampled) neighbours."""

    def __init__(self, features, num, cuda):
        super(MeanAggregation, self).__init__()
        self.features = features  # embedding lookup: LongTensor of ids -> (n, dim)
        # NOTE(review): shadows nn.Module.cuda(); name kept for compatibility.
        self.cuda = cuda
        self.num = num

    def forward(self, nodes, neigh1, num_sample=4):
        """Return a (len(nodes), feature_dim) tensor of neighbour means.

        nodes   : sequence of node ids (one row of output per node)
        neigh1  : per-node iterable of neighbour ids, aligned with `nodes`
        num_sample : max neighbours sampled per node (was previously ignored
                     and hard-coded to 40 inside the body)
        """
        # Sample without replacement when a node has more neighbours than
        # num_sample. list() is required: random.sample on a set is an
        # error on Python >= 3.11.
        neigh = [
            set(random.sample(list(to_neigh), num_sample))
            if len(to_neigh) >= num_sample else to_neigh
            for to_neigh in neigh1
        ]
        # set().union tolerates an empty neigh list and non-set members,
        # unlike set.union(*neigh) which required neigh[0] to be a set.
        unique_node_list = list(set().union(*neigh))
        unique_node_map = {n: i for i, n in enumerate(unique_node_list)}
        ids = torch.LongTensor(unique_node_list)
        if self.cuda:
            ids = ids.cuda()
        embed_matrix = self.features(ids)
        # mask[i, j] = 1 iff unique node j is a (sampled) neighbour of nodes[i].
        # (Variable wrapper dropped — plain tensors are autograd-aware since torch 0.4.)
        mask = torch.zeros(len(nodes), len(unique_node_list))
        column_indices = [unique_node_map[n] for samp in neigh for n in samp]
        row_indices = [i for i, samp in enumerate(neigh) for _ in samp]
        mask[row_indices, column_indices] = 1
        if self.cuda:
            mask = mask.cuda()
        # Row-normalise to a mean; clamp guards nodes with zero neighbours
        # (previously 0/0 produced NaN rows).
        mask = mask / mask.sum(1, keepdim=True).clamp(min=1)
        return mask.mm(embed_matrix)
class GraphSage(nn.Module):
    """Classifier head: linear scoring over encoder embeddings + sigmoid."""

    def __init__(self, num_classes, enc):
        super(GraphSage, self).__init__()
        # Final scoring layer: (num_classes, embed_dim).
        # (was torch.torch.FloatTensor — accidental module self-reference)
        self.fc1 = nn.Parameter(torch.FloatTensor(num_classes, enc.embed_dim))
        self.enc = enc
        init.xavier_uniform_(self.fc1)  # in-place variant; bare xavier_uniform is deprecated
        # L1 loss; NOTE(review): the paper reportedly uses a sigmoid/logistic
        # loss — confirm before training (nn.BCELoss would match that).
        self.xent = nn.L1Loss()
        # Sigmoid because labels are multi-label; attribute name `softmax`
        # kept for backward compatibility with existing callers.
        self.softmax = torch.nn.Sigmoid()

    def forward(self, nodes):
        """Return per-class scores in (0, 1), shape (num_classes, len(nodes))."""
        enc_op = self.enc(nodes)   # (embed_dim, n)
        fe1 = self.fc1.mm(enc_op)  # (num_classes, n)
        return self.softmax(fe1)

    def loss(self, nodes, labels):
        """L1 loss between predictions (transposed to (n, num_classes)) and labels."""
        op = self.forward(nodes)
        # Bug fix: labels.cuda() crashed on CPU-only machines; move labels
        # to whichever device the predictions live on instead.
        return self.xent(op.t(), labels.to(op.device))