I am trying to implement a dot-product attention block. There are two types of implementation with different computational cost: one computes $z = (QK^T)V/n$, the other computes $z = Q(K^T V)/n$. Theoretically, the first has quadratic complexity in the sequence length $n$, while the second has linear complexity in $n$.
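To make the complexity claim concrete, here is my rough operation count (per batch element and head, ignoring constant factors): forming $QK^T$ takes about $n^2 d_k$ multiply-adds and multiplying the resulting $n \times n$ matrix by $V$ takes about $n^2 d_v$ more, so type 1 is $O(n^2 (d_k + d_v))$. Forming $K^T V$ takes about $n\, d_k d_v$ and multiplying $Q$ by the resulting $d_k \times d_v$ matrix takes another $n\, d_k d_v$, so type 2 is $O(n\, d_k d_v)$, i.e. linear in the sequence length.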
The CPU implementation is as follows:
import time
import torch

b = 5      # batch size
n = 256    # sequence length
h = 4      # number of heads
dk = 16    # key/query dimension
dv = 16    # value dimension
Q = torch.rand((b, h, n, dk))
K = torch.rand((b, h, n, dk))
V = torch.rand((b, h, n, dv))

# type 1 : z = (Q K^T) V
times = []
niters = 10000
for i in range(niters):
    torch.cuda.synchronize()  # only relevant when timing on the GPU
    start_epoch = time.time()
    QK = torch.einsum('bhnd,bhmd->bhnm', Q, K)   # n x n attention matrix
    v1 = torch.einsum('bhnm,bhmd->bhnd', QK, V)  # multiply it by V
    end_epoch = time.time()
    elapsed = end_epoch - start_epoch
    times.append(elapsed)
avg_time = sum(times) / niters
print(avg_time)  # result on my machine 0.0004075302600860596
# type 2 : z = Q (K^T V)
times = []
niters = 10000
for i in range(niters):
    torch.cuda.synchronize()
    start_epoch = time.time()
    KV = torch.einsum('bhni,bhnj->bhij', K, V)   # dk x dv matrix
    v2 = torch.einsum('bhni,bhij->bhnj', Q, KV)  # multiply Q by it
    end_epoch = time.time()
    elapsed = end_epoch - start_epoch
    times.append(elapsed)
avg_time = sum(times) / niters
print(avg_time)  # result on my machine 0.00023783543109893798
So in the CPU case, type 1 is clearly slower than type 2, as expected.
However, if I put Q, K, V on CUDA and run the same loops, type 1 takes 0.00014664556980133057 s on my machine, while type 2 takes 0.0001561680793762207 s. Type 2 on the GPU is slower than type 1, which is inconsistent with the CPU result.
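For reference, the GPU run is essentially the same timing loop with the tensors moved onto the device first; a sketch of that setup (the Qg/Kg/Vg names are just shorthand here, shown for type 1):

Qg = Q.to('cuda')  # move the tensors onto the GPU
Kg = K.to('cuda')
Vg = V.to('cuda')
times = []
for i in range(niters):
    torch.cuda.synchronize()  # wait for previously launched kernels before starting the clock
    start_epoch = time.time()
    QK = torch.einsum('bhnd,bhmd->bhnm', Qg, Kg)
    v1 = torch.einsum('bhnm,bhmd->bhnd', QK, Vg)
    end_epoch = time.time()
    times.append(end_epoch - start_epoch)
print(sum(times) / niters)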
I'm quite confused by this result. Why does type 2 lose its advantage on the GPU?