I have a scaled_attention function as follows:
```python
import math

import torch
import torch.nn.functional as F
from entmax import entmax15, sparsemax  # pip install entmax


def scaled_attention(query, key, value, mask=None, attention_type=0):
    # Scaled dot-product scores: (..., L_q, d) @ (..., d, L_k) -> (..., L_q, L_k)
    qk = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.shape[-1])
    if mask is not None:
        # Positions where mask == 1 get a large negative score
        qk = qk.masked_fill(mask == 1, -1e9)

    # section 1 for PyTorch forum
    # qk = F.softmax(qk, dim=-1)
    qk = entmax15(qk, dim=-1)
    # qk = sparsemax(qk, dim=-1)

    # section 2 for PyTorch forum
    # if attention_type == 0:
    #     qk = F.softmax(qk, dim=-1)
    # else:
    #     if attention_type == 1:
    #         qk = entmax15(qk, dim=-1)
    #     else:
    #         if attention_type == 2:
    #             qk = sparsemax(qk, dim=-1)

    return torch.matmul(qk, value)
```
When I manually set the activation function through section 1 (ignoring the argument and uncommenting the desired activation), the model trains as expected. However, when I use the if/else statements in section 2 instead, training stagnates and the model doesn't learn anything.
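For comparison, a flattened version of section 2 might look like the sketch below. It assumes the same three codes (0 = softmax, 1 = entmax15, 2 = sparsemax) and, as a safeguard not in the original, raises on any other value. In the nested version, an `attention_type` that matches none of 0, 1 or 2 (for example a string) skips every branch, so `qk` reaches the final `torch.matmul` as unnormalized scores without any error.

```python
import math

import torch
import torch.nn.functional as F


def scaled_attention(query, key, value, mask=None, attention_type=0):
    # Scaled dot-product scores, shape (..., L_q, L_k)
    qk = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.shape[-1])
    if mask is not None:
        # Positions where mask == 1 are pushed to a large negative score
        qk = qk.masked_fill(mask == 1, -1e9)

    # Flat if/elif/else: exactly one branch normalizes qk, and an
    # unrecognized attention_type raises instead of silently falling through.
    if attention_type == 0:
        qk = F.softmax(qk, dim=-1)
    elif attention_type == 1:
        from entmax import entmax15  # lazy import; needs the entmax package
        qk = entmax15(qk, dim=-1)
    elif attention_type == 2:
        from entmax import sparsemax  # lazy import; needs the entmax package
        qk = sparsemax(qk, dim=-1)
    else:
        raise ValueError(f"unsupported attention_type: {attention_type!r}")

    return torch.matmul(qk, value)
```

With this version, a call such as `scaled_attention(q, k, v, attention_type="entmax")` fails immediately with a `ValueError` rather than training on raw scores.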
I've tried putting breakpoints at the activation statements to check which of them are being reached, and everything appears to run normally.
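Beyond breakpoints, one way to confirm at runtime that normalization actually happened is a check like the hypothetical helper `check_attention_weights` sketched here (the name and default tolerance are illustrative, not part of PyTorch or entmax). softmax, entmax15 and sparsemax all map each row to a probability distribution (non-negative entries summing to 1), so calling the helper on `qk` right before the final `torch.matmul` would flag raw, unnormalized scores.

```python
import torch


def check_attention_weights(weights: torch.Tensor, atol: float = 1e-5) -> None:
    """Raise if `weights` is not a probability distribution along the last dim.

    softmax, entmax15 and sparsemax all produce non-negative rows summing to 1,
    so unnormalized attention scores fail this check.
    """
    if (weights < -atol).any():
        raise AssertionError("attention weights contain negative values")
    row_sums = weights.sum(dim=-1)
    if not torch.allclose(row_sums, torch.ones_like(row_sums), atol=atol):
        max_dev = (row_sums - 1).abs().max().item()
        raise AssertionError(f"attention rows do not sum to 1 (max deviation {max_dev:.3g})")
```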
Is there any specific reason for this behaviour?