Hey,
I got an "unexpected keyword argument" error from MultiheadAttention.forward
with the following code:
import torch

att = torch.nn.MultiheadAttention(embed_dim=512, num_heads=8)
x = torch.randn(1, 5, 512)  # (seq_len=1, batch=5, embed_dim=512) under the default layout
att(x, x, x, average_attn_weights=False)
which raised:
forward() got an unexpected keyword argument 'average_attn_weights'
According to the documentation, however, average_attn_weights should be a keyword argument of forward. Does anyone know why this happens?
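In case it helps narrow this down, here is a small check I can run. It assumes the cause might be a mismatch between the PyTorch release the documentation describes and the one I have installed, which I have not confirmed. The helper name accepts_keyword is my own (hypothetical) and not part of PyTorch; it only uses the standard inspect module to look at the installed forward signature.

```python
import inspect


def accepts_keyword(fn, name):
    """Return True if callable `fn` can be called with `name` as a keyword.

    Hypothetical helper for this post, not a PyTorch API.
    """
    params = inspect.signature(fn).parameters
    if name in params:
        # Positional-only parameters cannot be passed by keyword.
        return params[name].kind is not inspect.Parameter.POSITIONAL_ONLY
    # Otherwise the keyword is only accepted if the callable takes **kwargs.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


try:
    import torch
except ImportError:
    torch = None  # lets the helper be tried out without PyTorch installed

if torch is not None:
    # Compare this version with the version selected on the documentation page.
    print("torch version:", torch.__version__)
    print(
        "forward accepts average_attn_weights:",
        accepts_keyword(torch.nn.MultiheadAttention.forward, "average_attn_weights"),
    )
```

If the second line prints False, the installed forward really does not have that parameter, and the documentation page is probably for a different release than the installed one.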
Best, Paul