I have the following implementation in my PyTorch-based code, which involves a nested for loop. The nested for loop, together with the if condition, makes the code very slow to execute. I attempted to avoid the nested loops by using the broadcasting concepts of NumPy and PyTorch, but that did not yield any result. Any help with eliminating the for loops would be appreciated.
Here are the links I have already read (none of them solved my problem):
#!/usr/bin/env python
# coding: utf-8
import torch
batch_size = 32
# ~20% of the batch rows are routed to the student branch (True); the rest
# go to the teacher branch (False).
mask = torch.FloatTensor(batch_size).uniform_() > 0.8
teacher_count = 510
student_count = 420
feature_dim = 750
student_output = torch.zeros([batch_size, student_count])
teacher_output = torch.zeros([batch_size, teacher_count])
# BUG FIX: torch.randint's upper bound is exclusive, so randint(0, 1, ...)
# always produces 0 — both adjacency matrices (and hence every output) would
# be identically zero.  randint(0, 2, ...) gives a genuine random 0/1
# adjacency matrix, consistent with the reference setup later in the file.
student_adjacency_mat = torch.randint(0, 2, (student_count, student_count))
teacher_adjacency_mat = torch.randint(0, 2, (teacher_count, teacher_count))
student_feat = torch.rand([batch_size, feature_dim])
student_graph = torch.rand([student_count, feature_dim])
teacher_feat = torch.rand([batch_size, feature_dim])
teacher_graph = torch.rand([teacher_count, feature_dim])
# Original (slow) approach: a Python triple loop that, for each batch row,
# accumulates out[m, i] += A[i, j] * <feat[m], graph[j]> over every (i, j).
for m in range(batch_size):
    if mask[m]:
        # masked rows use the student graph
        for i in range(student_count):
            for j in range(student_count):
                student_output[m, i] += student_adjacency_mat[i, j] * torch.dot(student_feat[m], student_graph[j])
    else:
        # unmasked rows use the teacher graph
        for i in range(teacher_count):
            for j in range(teacher_count):
                teacher_output[m, i] += teacher_adjacency_mat[i, j] * torch.dot(teacher_feat[m], teacher_graph[j])
# create reference
# Same setup as above, but with much smaller sizes so the brute-force loop
# below finishes quickly and can be checked against the vectorized version.
batch_size = 32
mask = torch.FloatTensor(batch_size).uniform_() > 0.8   # True -> student, False -> teacher
teacher_count = 40
student_count = 50
feature_dim = 60
student_output = torch.zeros(batch_size, student_count)
teacher_output = torch.zeros(batch_size, teacher_count)
# Random 0/1 adjacency matrices (randint's upper bound is exclusive).
student_adjacency_mat = torch.randint(0, 2, (student_count, student_count))
teacher_adjacency_mat = torch.randint(0, 2, (teacher_count, teacher_count))
student_feat = torch.rand(batch_size, feature_dim)
student_graph = torch.rand(student_count, feature_dim)
teacher_feat = torch.rand(batch_size, feature_dim)
teacher_graph = torch.rand(teacher_count, feature_dim)
# loop approach
# Reference implementation: explicit Python loops (slow, but unambiguous).
# Accumulates into a local before writing back; float32 arithmetic is
# identical to updating the tensor element on every inner iteration.
for m in range(batch_size):
    use_student = bool(mask[m])
    if use_student:
        for i in range(student_count):
            total = student_output[m][i]
            for j in range(student_count):
                total = total + student_adjacency_mat[i][j] * torch.dot(student_feat[m], student_graph[j])
            student_output[m][i] = total
    else:
        for i in range(teacher_count):
            total = teacher_output[m][i]
            for j in range(teacher_count):
                total = total + teacher_adjacency_mat[i][j] * torch.dot(teacher_feat[m], teacher_graph[j])
            teacher_output[m][i] = total
# without loops
# out[m, i] = sum_j A[i, j] * <feat[m], graph[j]>  ==  (feat @ graph.T) @ A.T
# Using a second matmul instead of `A.unsqueeze(1) * sim` avoids
# materializing the (count, batch, count) broadcast intermediate — O(count * batch)
# memory instead of O(count^2 * batch).  The adjacency matrix is integer-typed,
# so cast it to float: elementwise * promotes dtypes automatically, matmul does not.
sim = torch.matmul(student_feat, student_graph.T)             # (batch, student_count)
out = torch.matmul(sim, student_adjacency_mat.float().T)      # (batch, student_count)
student_out = mask.float().unsqueeze(1) * out                 # zero out teacher rows
print((student_out - student_output).abs().max())
# tensor(9.1553e-05)
sim = torch.matmul(teacher_feat, teacher_graph.T)
out = torch.matmul(sim, teacher_adjacency_mat.float().T)
teacher_out = (~mask).float().unsqueeze(1) * out              # zero out student rows
print((teacher_out - teacher_output).abs().max())
# tensor(9.1553e-05)