Attention for image classification

For an input image of size 3×28×28:

import torch
import torch.nn as nn

inp = torch.randn(1, 3, 28, 28)
x = nn.MultiheadAttention(28, 2)
x(inp[0], torch.randn(28, 28), torch.randn(28, 28))[0].shape


torch.Size([3, 28, 28])


x(inp[0], torch.randn(28, 28), torch.randn(28, 28))[1].shape


torch.Size([28, 3, 1])

What is the correct way of using nn.MultiheadAttention for images?
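One way to do it (a sketch, not the only correct recipe): treat each of the H×W spatial locations as a token whose embedding is its channel vector. Three channels is too small to split across heads, so this sketch first projects to a hypothetical `embed_dim` with a 1×1 convolution; `embed_dim=32` and `num_heads=2` are illustrative values, not from the original post.

```python
import torch
import torch.nn as nn

B, C, H, W = 1, 3, 28, 28
embed_dim, num_heads = 32, 2  # illustrative values

img = torch.randn(B, C, H, W)
proj = nn.Conv2d(C, embed_dim, kernel_size=1)      # per-pixel embedding
mha = nn.MultiheadAttention(embed_dim, num_heads)  # expects (L, N, E)

tokens = proj(img).flatten(2).permute(2, 0, 1)     # (H*W, B, embed_dim)
out, attn = mha(tokens, tokens, tokens)            # self-attention over pixels
print(out.shape)   # torch.Size([784, 1, 32])
print(attn.shape)  # (N, L, S) = torch.Size([1, 784, 784])
```

Here query, key, and value are all the same sequence, so every pixel attends to every other pixel; the attention weights have one row per query pixel.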

Does this MultiheadAttention work something like paying attention only to a part of the image?

I think this module is for sequence models, not for image classification. It is based on the paper "Attention Is All You Need".

But attention is applicable to images too, like here,

is nn.MultiheadAttention applicable for images?

I want the neural network to focus on part of the image. I think attention does something similar with text, focusing on some part of it.

But nn.MultiheadAttention is just for sequence models.

I guess you mean techniques for applying attention to convolutional networks.

Attention is like a new wave for convnets.
You can apply it either by changing the architecture, by changing the loss function, or both.

The problem with convolution is that it has a local receptive field.
In contrast, fully connected layers have a global receptive field. The idea of combining the two using SE (Squeeze-and-Excitation) blocks is here.
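To make the SE idea concrete, here is a minimal sketch of a Squeeze-and-Excitation block (parameter names and the reduction ratio of 16 are from the usual recipe, not from this thread): globally average-pool to get one value per channel, pass it through a small bottleneck MLP, and use the result to rescale the channels.

```python
import torch
import torch.nn as nn

class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid(),
        )

    def forward(self, x):                          # x: (B, C, H, W)
        w = x.mean(dim=(2, 3))                     # squeeze: (B, C)
        w = self.fc(w).view(x.size(0), -1, 1, 1)   # excite: per-channel gates
        return x * w                               # rescale the feature map

x = torch.randn(2, 64, 28, 28)
print(SEBlock(64)(x).shape)  # torch.Size([2, 64, 28, 28])
```

The global pooling is what gives each channel gate a global receptive field, complementing the local receptive field of the surrounding convolutions.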

Also, there is an idea to concatenate convolutional and attentional feature maps here.

The same ideas, but for GANs, can be found in SAGAN.
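A SAGAN-style self-attention layer over a conv feature map looks roughly like this (a sketch; the C//8 bottleneck and the learned scale gamma follow the paper's description, other details are simplified): 1×1 convolutions produce query/key/value maps, attention is computed between all spatial positions, and the result is added back residually.

```python
import torch
import torch.nn as nn

class SelfAttention2d(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.q = nn.Conv2d(channels, channels // 8, 1)
        self.k = nn.Conv2d(channels, channels // 8, 1)
        self.v = nn.Conv2d(channels, channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))  # starts as identity map

    def forward(self, x):                          # x: (B, C, H, W)
        B, C, H, W = x.shape
        q = self.q(x).flatten(2).transpose(1, 2)   # (B, HW, C//8)
        k = self.k(x).flatten(2)                   # (B, C//8, HW)
        v = self.v(x).flatten(2)                   # (B, C, HW)
        attn = torch.softmax(q @ k, dim=-1)        # (B, HW, HW)
        out = (v @ attn.transpose(1, 2)).view(B, C, H, W)
        return self.gamma * out + x

x = torch.randn(2, 32, 16, 16)
print(SelfAttention2d(32)(x).shape)  # torch.Size([2, 32, 16, 16])
```

Because gamma is initialized to zero, the layer initially passes its input through unchanged and only gradually learns to mix in the attended features.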

But I think these papers haven't been implemented in PyTorch yet. It may take a few months for a good paper to make it into PyTorch.


This paper can also be useful, depending on what you are trying to do.


What are dv and dk here?

We now formally describe our proposed Attention Augmentation method. We use the following naming conventions: H, W and Fin refer to the height, width and number of input filters of an activation map. Nh, dv and dk respectively refer to the number of heads, the depth of values and the depth of queries and keys in multihead-attention (MHA). We further assume that Nh divides dv and dk evenly and denote dhv and dhk the depth of values and queries/keys per attention head.

If I do nn.MultiheadAttention(28, 2), then Nh = 2, but what are dv, dk, dhv, and dhk?
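With nn.MultiheadAttention's defaults (kdim and vdim default to embed_dim), the paper's notation maps onto the module roughly like this (a sketch of the correspondence, assuming default arguments):

```python
import torch.nn as nn

# Paper notation -> nn.MultiheadAttention(embed_dim=28, num_heads=2):
#   Nh        = num_heads                = 2
#   dv = dk   = embed_dim                = 28  (kdim/vdim default to embed_dim)
#   dhv = dhk = embed_dim // num_heads   = 14  (depth per attention head)
mha = nn.MultiheadAttention(embed_dim=28, num_heads=2)
print(mha.embed_dim, mha.num_heads, mha.head_dim)  # 28 2 14
```

The requirement in the paper that Nh divides dv and dk evenly shows up in PyTorch as the constraint that embed_dim must be divisible by num_heads.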

If I want to transform an image into another image, then

transformer_model = nn.Transformer(img_size, n_heads)
transformer_model(source_image, target_image)

is this the correct way to use nn.Transformer for images?
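Not quite as written: nn.Transformer's forward expects sequences of shape (seq_len, batch, d_model) for both src and tgt, not raw (B, C, H, W) images. One hypothetical setup (a sketch with illustrative sizes, positional encodings omitted for brevity): project each pixel's channels to d_model and treat the H*W pixels as the sequence.

```python
import torch
import torch.nn as nn

B, C, H, W = 1, 3, 28, 28
d_model, n_heads = 32, 2  # illustrative values

src_img = torch.randn(B, C, H, W)
tgt_img = torch.randn(B, C, H, W)

proj = nn.Conv2d(C, d_model, kernel_size=1)
# one layer each just to keep the sketch cheap; defaults are 6+6
model = nn.Transformer(d_model=d_model, nhead=n_heads,
                       num_encoder_layers=1, num_decoder_layers=1)

def to_seq(im):
    return proj(im).flatten(2).permute(2, 0, 1)  # (H*W, B, d_model)

out = model(to_seq(src_img), to_seq(tgt_img))
print(out.shape)  # torch.Size([784, 1, 32])
```

The output sequence can then be projected back to C channels and reshaped to (B, C, H, W) if an image is needed.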

Were you able to figure out how to use MHA for images? Would appreciate pointers if you have!