I have an operation that can be reproduced with the following code:
```python
import torch
import torch.nn as nn

n_node = 9
n_bond_features = 4

reg_sig = torch.randn([32, n_node, 5])
reg_adj = torch.randn([32, n_node, n_node, n_bond_features])

SM_f = nn.Softmax(dim=2)
SM_W = nn.Softmax(dim=3)
p_f = SM_f(reg_sig)
p_W = SM_W(reg_adj)

V = torch.zeros([reg_sig.size(0), n_node])
for i in range(n_node):
    for j in range(n_node):
        if i != j:  # skip self-edges; `is not` compares identity, not value
            for k in range(n_bond_features):
                V[:, i] += k * p_W[:, i, j, k]
```
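In formula form, the loop computes, for every batch element b and node i:

```latex
V_{b,i} = \sum_{j \neq i} \; \sum_{k=0}^{n_{\text{bond}} - 1} k \cdot p_W[b, i, j, k]
```

Since each `p_W[b, i, j, :]` is a softmax distribution, the inner sum is the expected bond-feature index for the edge (i, j), accumulated over all neighbors j ≠ i.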
I would like to get rid of the loops. Here is my vectorized implementation:
```python
# Weight vector [0, 1, ..., n_bond_features - 1]
h_vec = torch.arange(n_bond_features).float()

# inner_sum[b, j, k] = sum_i i * p_W[b, j, k, i]
inner_sum = torch.einsum('i,bjki->bjk', h_vec, p_W)

# Subtract the diagonal (j == k terms), then sum over the neighbor axis
V = inner_sum - torch.diag_embed(torch.einsum('...ii->...i', inner_sum))
V = V.sum(2)
```
I was wondering if this is the most efficient way to do it.
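For comparison, here is a self-contained sketch that checks the einsum version against the loop and also tries a single-einsum variant that masks out the diagonal up front (the `mask` variable and the variant's names are just illustrative, not from the original code):

```python
import torch

torch.manual_seed(0)
n_node, n_bond_features = 9, 4
reg_adj = torch.randn(32, n_node, n_node, n_bond_features)
p_W = torch.softmax(reg_adj, dim=3)

# Reference: the original triple loop
V_loop = torch.zeros(reg_adj.size(0), n_node)
for i in range(n_node):
    for j in range(n_node):
        if i != j:
            for k in range(n_bond_features):
                V_loop[:, i] += k * p_W[:, i, j, k]

# Vectorized version from the question
h_vec = torch.arange(n_bond_features).float()
inner_sum = torch.einsum('i,bjki->bjk', h_vec, p_W)
V_einsum = (inner_sum - torch.diag_embed(torch.einsum('...ii->...i', inner_sum))).sum(2)

# Alternative: zero the diagonal with an (n_node, n_node) mask, one einsum call
mask = 1.0 - torch.eye(n_node)
V_mask = torch.einsum('k,jl,bjlk->bj', h_vec, mask, p_W)

print(torch.allclose(V_loop, V_einsum, atol=1e-4))  # True
print(torch.allclose(V_loop, V_mask, atol=1e-4))    # True
```

The masked variant avoids materializing the diagonal embedding and fuses the neighbor sum into the einsum itself, which may or may not be faster in practice; it is worth benchmarking both on your actual sizes.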