Efficient PyTorch broadcasting to avoid a nested-loop bottleneck

I have the following implementation in my PyTorch-based code, which involves nested for loops. The nested loops, together with the if condition, make the code very slow to execute. I attempted to replace the nested loops with NumPy/PyTorch broadcasting, but that did not yield any result. Any help with avoiding the for loops would be appreciated.

The links I have already read are in the references at the end of this post. Here is a minimal version of my code:

    #!/usr/bin/env python
    # coding: utf-8

    import torch
    
    batch_size=32
    mask=torch.FloatTensor(batch_size).uniform_() > 0.8
    
    teacher_count=510
    student_count=420
    feature_dim=750
    student_output=torch.zeros([batch_size,student_count])
    teacher_output=torch.zeros([batch_size,teacher_count])
    
    # random 0/1 adjacency matrices; note randint's upper bound is exclusive,
    # so the high value must be 2 (with high=1 the matrices would be all zeros)
    student_adjacency_mat=torch.randint(0,2,(student_count,student_count))
    teacher_adjacency_mat=torch.randint(0,2,(teacher_count,teacher_count))

    student_feat=torch.rand([batch_size,feature_dim])
    student_graph=torch.rand([student_count,feature_dim])
    teacher_feat=torch.rand([batch_size,feature_dim])
    teacher_graph=torch.rand([teacher_count,feature_dim])


    # per-sample routing: student branch where mask is True, teacher branch otherwise
    for m in range(batch_size):
        if mask[m]==1:
            for i in range(student_count):
                for j in range(student_count):
                    student_output[m][i]=student_output[m][i]+student_adjacency_mat[i][j]*torch.dot(student_feat[m],student_graph[j])
        if mask[m]==0:
            for i in range(teacher_count):
                for j in range(teacher_count):
                    teacher_output[m][i]=teacher_output[m][i]+teacher_adjacency_mat[i][j]*torch.dot(teacher_feat[m],teacher_graph[j])

References: [1] Broadcasting semantics — PyTorch 1.13 documentation
[2] Broadcasting — NumPy v1.24 Manual

This should work. The idea: all the torch.dot calls for one sample are entries of a single matrix product, so the whole batch of dot products collapses into one matmul; the adjacency matrix is then broadcast against that result and summed over the neighbour dimension, and the if conditions become multiplications by the mask:

    import torch

    # create reference data (sizes reduced from the question so the loop check runs quickly)
    batch_size=32
    mask=torch.FloatTensor(batch_size).uniform_() > 0.8

    teacher_count=40
    student_count=50
    feature_dim=60
    student_output=torch.zeros([batch_size,student_count])
    teacher_output=torch.zeros([batch_size,teacher_count])

    student_adjacency_mat=torch.randint(0,2,(student_count,student_count))
    teacher_adjacency_mat=torch.randint(0,2,(teacher_count,teacher_count))

    student_feat=torch.rand([batch_size,feature_dim])
    student_graph=torch.rand([student_count,feature_dim])
    teacher_feat=torch.rand([batch_size,feature_dim])
    teacher_graph=torch.rand([teacher_count,feature_dim])

    # loop approach (identical to the question, used as the reference)
    for m in range(batch_size):
        if mask[m]==1:
            for i in range(student_count):
                for j in range(student_count):
                    student_output[m][i]=student_output[m][i]+student_adjacency_mat[i][j]*torch.dot(student_feat[m],student_graph[j])
        if mask[m]==0:
            for i in range(teacher_count):
                for j in range(teacher_count):
                    teacher_output[m][i]=teacher_output[m][i]+teacher_adjacency_mat[i][j]*torch.dot(teacher_feat[m],teacher_graph[j])

    # without loops: every torch.dot(student_feat[m], student_graph[j]) is one
    # entry of a single (batch_size, student_count) matmul
    feat = torch.matmul(student_feat, student_graph.T)
    # broadcast (student_count, 1, student_count) * (batch_size, student_count)
    # -> (student_count, batch_size, student_count)
    feat = student_adjacency_mat.unsqueeze(1) * feat
    # sum over j, then transpose to (batch_size, student_count)
    out = feat.sum(2).T
    # multiplying by the mask replaces the if condition
    student_out = mask.float().unsqueeze(1) * out
    print((student_out - student_output).abs().max())
    # tensor(9.1553e-05)
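
Side note (my own sketch, not needed for correctness): since out[m, i] = sum_j A[i, j] * (feat @ graph.T)[m, j], the same result is also a plain chain of matrix products, which avoids materializing the (student_count, batch_size, student_count) intermediate. Unlike elementwise *, matmul and einsum do not type-promote, so the integer adjacency matrix needs an explicit .float() cast:

    # equivalent forms that skip the 3-D intermediate; the .float() cast is
    # required because matmul/einsum need matching dtypes
    out_mm = student_feat @ student_graph.T @ student_adjacency_mat.float().T
    out_es = torch.einsum('ij,mk,jk->mi',
                          student_adjacency_mat.float(), student_feat, student_graph)
    print((out_mm - out).abs().max())  # again only float32 round-off

For the sizes in the original question (510x510 adjacency, batch 32), skipping that 3-D intermediate saves on the order of tens of megabytes per forward pass.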

    # same pattern for the teacher branch, gated by the inverted mask
    feat = torch.matmul(teacher_feat, teacher_graph.T)
    feat = teacher_adjacency_mat.unsqueeze(1) * feat
    out = feat.sum(2).T
    teacher_out = (~mask).float().unsqueeze(1) * out
    print((teacher_out - teacher_output).abs().max())
    # tensor(9.1553e-05)

The ~1e-4 differences are just float32 round-off: the broadcast version accumulates the sums in a different order than the loops.
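
If you want to sanity-check the speedup, here is a rough CPU timing sketch using the tensors defined above; absolute numbers will vary by machine, so treat it as an illustration rather than a benchmark:

    import time

    # time the loop version (student branch only, for brevity)
    start = time.perf_counter()
    ref = torch.zeros([batch_size, student_count])
    for m in range(batch_size):
        if mask[m] == 1:
            for i in range(student_count):
                for j in range(student_count):
                    ref[m][i] = ref[m][i] + student_adjacency_mat[i][j] * torch.dot(student_feat[m], student_graph[j])
    loop_s = time.perf_counter() - start

    # time the broadcast version
    start = time.perf_counter()
    out = (student_adjacency_mat.unsqueeze(1) * torch.matmul(student_feat, student_graph.T)).sum(2).T
    vec = mask.float().unsqueeze(1) * out
    vec_s = time.perf_counter() - start

    print(f"loop: {loop_s:.3f}s, broadcast: {vec_s:.5f}s")

On GPU you would additionally need torch.cuda.synchronize() before reading the clock, since CUDA kernels launch asynchronously.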

Thanks a lot. Your approach saved a lot of time for me.
