Hello, I think I am doing the same multiplication in two different ways, but the results do not match.
import torch
# Self-attention with 100 tokens
n_token = 100
c_hidden = 128
query = torch.randn(n_token, c_hidden)
key = torch.randn(n_token, c_hidden)
attn_matrix = torch.matmul(query, key.transpose(0, 1))  # (query_token_index, key_token_index)
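# attn_matrix has shape (n_token, n_token); entry (i, j) is the dot product
# of query row i with key row j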
# Graph framework: here every token is fully connected with every other token (including itself)
edge_index = torch.ones(n_token, n_token).nonzero(as_tuple=True)  # every token is connected to every token
# edge_index = (source_token_index, destination_token_index)
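# for example, with n_token = 2 nonzero() enumerates all 4 pairs in row-major order:
# edge_index[0] = tensor([0, 0, 1, 1])  (source_token_index)
# edge_index[1] = tensor([0, 1, 0, 1])  (destination_token_index)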
query_graph = query[edge_index[0]]  # query rows gathered at source_token_index
key_graph = key[edge_index[1]]  # key rows gathered at destination_token_index
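# both gathered tensors have shape (E, c_hidden), one row per edge, where E = n_token * n_token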
attn_graph_1 = (query_graph * key_graph).sum(dim=-1)
attn_graph_2 = torch.matmul(query_graph[:, None, :], key_graph[:, :, None]).squeeze()
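# shapes in attn_graph_2: (E, 1, c_hidden) @ (E, c_hidden, 1) -> (E, 1, 1), squeezed to (E,)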
diff_1 = torch.abs(attn_matrix[edge_index] - attn_graph_1).sum()
diff_2 = torch.abs(attn_matrix[edge_index] - attn_graph_2).sum()
print(diff_1)  # 0.0157
print(diff_2)  # 0.0089
I have some idea of why 'diff_1' is not 0, but I don't know why 'diff_2' is not 0. Does anyone have an idea?
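In case it helps, here is a sanity check I am planning to try. My guess (not a confirmed explanation) is that both differences come from float32 rounding, since matmul and sum may reduce the 128 products per dot product in different orders; if that is right, repeating the comparison in float64 should shrink the differences by several orders of magnitude.

# assumption: the mismatch is float32 rounding from different reduction orders
query64 = query.double()
key64 = key.double()
attn_matrix64 = torch.matmul(query64, key64.transpose(0, 1))
attn_graph64 = (query64[edge_index[0]] * key64[edge_index[1]]).sum(dim=-1)
print(torch.abs(attn_matrix64[edge_index] - attn_graph64).sum())
# if the rounding explanation holds, this should be far smaller than diff_1 / diff_2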