def run(x, y, z, d_k, mask, dropout, zero_pad=False, gamma=None):
    """Scaled dot-product attention with a learned monotonic distance decay.

    Raw scores ``x @ y^T / sqrt(d_k)`` are multiplied by
    ``exp(-softplus(gamma) * dist)`` (clamped to [1e-5, 1e5]). ``dist`` is
    computed without gradient: for query ``i`` and key ``j`` it is the square
    root of (masked-softmax attention mass that query ``i`` places on keys
    after ``j``) times ``|i - j|``. The rescaled scores are masked,
    softmax-normalised, optionally have their first query row zeroed, passed
    through ``dropout`` and finally multiplied with ``z``.

    Args:
        x: query tensor; assumed (batch, heads, seqlen, d_k) — TODO confirm.
        y: key tensor, same layout as ``x``.
        z: value tensor multiplied by the attention weights.
        d_k: key dimension used for the ``1 / sqrt(d_k)`` scaling.
        mask: tensor broadcastable to the score shape; entries where
            ``mask == 0`` are excluded. Must live on the same device as the
            scores.
        dropout: callable applied to the attention weights (e.g. a dropout
            module).
        zero_pad: if True, the attention weights of the first query row are
            replaced with zeros.
        gamma: decay-rate tensor; ``softplus(gamma).unsqueeze(0)`` must
            broadcast against the scores (e.g. shape (heads, 1, 1)). Required
            despite the ``None`` default.

    Returns:
        ``dropout(weights) @ z``.

    Raises:
        TypeError: if ``gamma`` is None.
    """
    if gamma is None:
        raise TypeError("run() requires a gamma tensor for the distance decay")

    scores = torch.matmul(x, y.transpose(-2, -1)) / math.sqrt(d_k)
    bs, head, seqlen = scores.size(0), scores.size(1), scores.size(2)
    # Derive the device from the inputs instead of relying on a global.
    device = scores.device

    with torch.no_grad():
        # |j - i| gap between query i and key j. The grid is seqlen x seqlen
        # with seqlen = number of queries, so this assumes query and key
        # lengths are equal. Kept float32 as in the original implementation.
        positions = torch.arange(seqlen, device=device)
        position_effect = (
            (positions[None, :] - positions[:, None]).abs().float()[None, None, :, :]
        )

        # Masked softmax of the raw scores, used only to measure distance.
        weights = F.softmax(scores.masked_fill(mask == 0, -1e32), dim=-1)
        weights = weights * mask.float()
        cum_weights = torch.cumsum(weights, dim=-1)
        total_weights = torch.sum(weights, dim=-1, keepdim=True)
        # Attention mass beyond key j, scaled by the position gap; clamp
        # guards the sqrt against tiny negative round-off. Computed under
        # no_grad, so no detach is needed.
        dist_scores = torch.clamp(
            (total_weights - cum_weights) * position_effect, min=0.0
        ).sqrt()

    # Decay rate forced negative via softplus; unsqueeze adds a leading
    # batch axis (e.g. (heads, 1, 1) -> (1, heads, 1, 1)).
    gamma = -F.softplus(gamma).unsqueeze(0)
    # exp(gamma * distance), clamped to [1e-5, 1e5] to keep the rescale finite.
    total_effect = torch.clamp((dist_scores * gamma).exp(), min=1e-5, max=1e5)

    scores = scores * total_effect
    scores = scores.masked_fill(mask == 0, -1e32)
    scores = F.softmax(scores, dim=-1)
    if zero_pad:
        pad_zero = scores.new_zeros(bs, head, 1, seqlen)
        scores = torch.cat([pad_zero, scores[:, :, 1:, :]], dim=2)
    scores = dropout(scores)
    return torch.matmul(scores, z)