I have an operation that can be reproduced with the following code:
```python
import torch
import torch.nn as nn

reg_sig = torch.randn([32, 9, 5])
reg_adj = torch.randn([32, 9, 9, 4])
n_node = 9
n_bond_features = 4

SM_f = nn.Softmax(dim=2)
SM_W = nn.Softmax(dim=3)
p_f = SM_f(reg_sig)
p_W = SM_W(reg_adj)

V = torch.zeros([reg_sig.size(0), n_node])
for i in range(n_node):
    for j in range(n_node):
        if i != j:  # skip self-pairs; `is not` compares identity, not value
            for k in range(n_bond_features):
                V[:, i] += k * p_W[:, i, j, k]
```
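Written out, the triple loop computes the quantity below. Here $S$ is an illustrative name (not in the code) for the feature-weighted sum over the last axis, and $K$ = `n_bond_features`. Excluding $j = i$ is the same as summing over all $j$ and subtracting the diagonal term:

$$
S_{b,i,j} = \sum_{k=0}^{K-1} k \, p^W_{b,i,j,k},
\qquad
V_{b,i} = \sum_{j \neq i} S_{b,i,j} = \sum_{j} S_{b,i,j} - S_{b,i,i}
$$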
I would like to get rid of the nested loops at the end. Here is my vectorized implementation:
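```python
h_vec = torch.arange(n_bond_features).float()
# inner_sum[b, j, k] = sum_i h_vec[i] * p_W[b, j, k, i]
inner_sum = torch.einsum('i,bjki->bjk', h_vec, p_W)
# zero the diagonal (the i == j terms) by subtracting a diagonal matrix
V = inner_sum - torch.diag_embed(torch.einsum('...ii->...i', inner_sum))
V = V.sum(2)
```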
I was wondering if this is the most efficient way to do it.
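One possible simplification, sketched below under the assumption that only `p_W` matters for `V` (as in the loop above): compute the weighted sum with a matrix-vector product (`p_W @ h_vec`), then subtract the diagonal with `Tensor.diagonal` instead of materializing a full diagonal matrix with `torch.diag_embed`. The names `S`, `V_einsum`, and `V_fast` are illustrative. The snippet checks both versions against the loop; it makes no claim about measured speed, which would need benchmarking on your own hardware.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, n_node, n_bond_features = 32, 9, 4
p_W = nn.Softmax(dim=3)(torch.randn(batch, n_node, n_node, n_bond_features))

# Reference: the original nested loops (using != rather than `is not`)
V_loop = torch.zeros(batch, n_node)
for i in range(n_node):
    for j in range(n_node):
        if i != j:
            for k in range(n_bond_features):
                V_loop[:, i] += k * p_W[:, i, j, k]

h_vec = torch.arange(n_bond_features, dtype=p_W.dtype)

# The einsum + diag_embed version from the question
inner_sum = torch.einsum('i,bjki->bjk', h_vec, p_W)
V_einsum = (inner_sum - torch.diag_embed(torch.einsum('...ii->...i', inner_sum))).sum(2)

# Alternative: batched matrix-vector product over the last axis -> [batch, n_node, n_node],
# then sum over j and subtract the j == i terms directly (no extra n_node x n_node matrix)
S = p_W @ h_vec
V_fast = S.sum(dim=2) - S.diagonal(dim1=1, dim2=2)

print(torch.allclose(V_fast, V_loop, atol=1e-5), torch.allclose(V_einsum, V_loop, atol=1e-5))
```

An equivalent option is `S.masked_fill(torch.eye(n_node, dtype=torch.bool), 0).sum(2)`; subtracting the diagonal avoids building the mask at all.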