Trying to understand nn.MultiheadAttention coming from Keras

Hi!

I’m trying to understand nn.MultiheadAttention (MultiheadAttention — PyTorch 1.8.1 documentation) and whether I can use it to compute self-attention the way it can be done with the Keras implementation (MultiHeadAttention layer).

With the Keras implementation I’m able to run self-attention over a 1D vector the following way:

import tensorflow as tf

layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
input_tensor = tf.keras.Input(shape=[8, 16])       # (batch, seq_len, features), batch dim is None
output_tensor = layer(input_tensor, input_tensor)  # query = value -> self-attention
print(output_tensor.shape)
(None, 8, 16)

I’ve tried to do the same with the PyTorch implementation but couldn’t make it work, possibly because of my limited knowledge of the Transformer implementation.

import torch
import torch.nn as nn
multihead_attn = nn.MultiheadAttention(16, 8)  # embed_dim=16, num_heads=8
input_tensor = torch.zeros(8, 60)              # shape (8, 60), i.e. only two dimensions
multihead_attn(input_tensor, input_tensor, input_tensor)

This fails with the following error:

-> 4624     tgt_len, bsz, embed_dim = query.size()
   4625     assert embed_dim == embed_dim_to_check
   4626     # allow MHA to have different sizes for the feature dimension

ValueError: not enough values to unpack (expected 3, got 2)

It seems the Keras function does some work under the hood to make things easier for the user and performs self-attention when Query and Value are the same. Is there a way to do the same with PyTorch?

Do you know what happens under the hood for 1D inputs in Keras?
I guess you could mimic this behavior, e.g. by using unsqueeze and/or expand on the input tensors, once you know how 1D tensors should be handled.
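For example (just a rough check, assuming TensorFlow 2.x is installed), you could pass a concrete batched tensor instead of a symbolic Input and inspect the shapes Keras actually works with:

import tensorflow as tf

layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
x = tf.zeros((4, 8, 16))   # (batch, seq_len, features)
out = layer(x, x)          # query = value -> self-attention
print(out.shape)           # (4, 8, 16)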

Hi, thanks for the reply.

I haven’t found any particular code in the Keras call / compute_attention functions that does some under-the-hood 1D transformation (at least none that I can identify). This goes a bit beyond my understanding, and I haven’t been able to find any 1D attention examples. Do you know if issues in the PyTorch repo can be used to ask how to do calculations like this?

I don’t think the GitHub repository is the right place to ask questions about the usage of modules; it’s intended for issues in case the framework itself runs into a bug.

For your use case: since you cannot see what Keras is doing in the background, you could try to compare the results between Keras and PyTorch by adding the missing dimensions.
E.g. the PyTorch layer expects query to have the shape:

query: (L,N,E) where L is the target sequence length, N is the batch size, E is the embedding dimension.

Based on this, one dimension is missing in your input (could it be the batch dimension?).
If so, use input_tensor = input_tensor.unsqueeze(1) and pass it to the module.
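A minimal sketch of what I mean (note that, besides the missing dimension, the last dimension of the input also has to equal embed_dim=16, otherwise the shape check inside the module will fail):

import torch
import torch.nn as nn

multihead_attn = nn.MultiheadAttention(embed_dim=16, num_heads=8)
x = torch.zeros(8, 16)     # (L, E) -> the batch dimension is missing
x = x.unsqueeze(1)         # (L, N, E) = (8, 1, 16)
attn_output, attn_weights = multihead_attn(x, x, x)
print(attn_output.shape)   # torch.Size([8, 1, 16])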

I think I found the solution: I can use nn.TransformerEncoderLayer (TransformerEncoderLayer — PyTorch 1.8.1 documentation) to perform self-attention myself.
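For reference, a minimal sketch of that approach (assuming d_model=16 and the (seq_len, batch, features) layout used in PyTorch 1.8.1; the layer wraps self-attention together with a feed-forward sub-layer, so it’s more than just the attention itself):

import torch
import torch.nn as nn

encoder_layer = nn.TransformerEncoderLayer(d_model=16, nhead=8)
x = torch.zeros(8, 1, 16)  # (seq_len, batch, d_model)
out = encoder_layer(x)     # self-attention followed by the feed-forward block
print(out.shape)           # torch.Size([8, 1, 16])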