I am working on a seq2seq model with attention.
When I visualize the attention, the scores are sometimes spread across the time axis, and sometimes they clearly focus on a particular time step (see below).
# toy example
spread_attention = [0.1, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1]
focused_attention = [0.0, 0.0, 0.9, 0.1, 0.0, 0.0, 0.0]
I was wondering if there’s a way to encourage the focused attention over the spread attention, perhaps as an additional loss term, like a regularization factor:
reg_att = attention_spreadness(attention)
loss = cross_entropy_loss + reg_att
loss.backward()
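To make the idea above concrete, here is a runnable toy version of that training step. Everything in it is made up for illustration: the random logits and attention scores, the target, and the weighting coefficient `lam` (the original pseudocode adds the penalty unweighted):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# toy setup: logits over 5 classes and unnormalized attention scores over 7 steps
logits = torch.randn(1, 5, requires_grad=True)
att_scores = torch.randn(7, requires_grad=True)
target = torch.tensor([2])

# attention weights are a softmax output, so they are non-negative and sum to 1
attention = F.softmax(att_scores, dim=0)

# penalty: total attention mass sitting below a 0.5 threshold
reg_att = -F.threshold(-attention, -0.5, 0.0).sum()

lam = 0.1  # hypothetical weighting coefficient for the penalty
loss = F.cross_entropy(logits, target) + lam * reg_att
loss.backward()  # gradients flow through both terms
```

This is just a sketch of the wiring; in the real model the attention weights would of course come from the attention mechanism rather than a standalone parameter.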
I was thinking of using a negative L1 or L2 penalty, but since attention weights are non-negative and sum to 1, the L1 norm is constant and carries no signal. Then I thought about using a threshold layer:
def attention_spreadness(attention):
    # sum of all attention scores strictly below the 0.5 threshold
    return -torch.nn.functional.threshold(-attention, -0.5, 0).sum()
attention_spreadness(torch.tensor([0.1, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1]))
# > 1.0
attention_spreadness(torch.tensor([0.0, 0.0, 0.9, 0.1, 0.0, 0.0, 0.0]))
# > 0.1
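For reference, here is a small self-contained check of the observation that the L1 norm cannot separate the two toy vectors, while the threshold penalty can (the function is restated so the snippet runs on its own):

```python
import torch
import torch.nn.functional as F

def attention_spreadness(attention):
    # sum of all attention scores strictly below the 0.5 threshold
    return -F.threshold(-attention, -0.5, 0.0).sum()

spread = torch.tensor([0.1, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1])
focused = torch.tensor([0.0, 0.0, 0.9, 0.1, 0.0, 0.0, 0.0])

# L1 norm is 1.0 for both vectors, so it cannot tell them apart
print(spread.abs().sum().item(), focused.abs().sum().item())

# the threshold penalty does separate them (roughly 1.0 vs 0.1)
print(attention_spreadness(spread).item())
print(attention_spreadness(focused).item())
```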
In this manner, it basically sums all the scores below a certain threshold, and that sum can be used as a penalty term.
Does anyone know better approaches to do this? Perhaps some literature on similar attempts at this?