TypeError: avg_pool2d(): argument 'kernel_size' (position 2) must be tuple of ints, not tuple

Hi. I’m trying to implement a pre-trained Vision Transformer and perform GeM pooling, but I get this error.

The code is:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from torchvision import models
from torchvision.models.feature_extraction import create_feature_extractor

class GeM(nn.Module):
    
    def __init__(self, p=3, eps=1e-6):
        super(GeM,self).__init__()
        self.p = Parameter(torch.ones(1)*p)
        self.eps = eps

    def forward(self, x):
        return gem(x, p=self.p, eps=self.eps)
        
    def __repr__(self):
        return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')'

def gem(x, p=3, eps=1e-6):

    return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p)

class Embedder(nn.Module):

    def __init__(self,architecture,gem_p=3,pretrained_flag = True):

        super(Embedder, self).__init__()

        #load the base model from PyTorch's pretrained models (imagenet pretrained)
        network = getattr(models,architecture)(pretrained=pretrained_flag)
        
        if architecture.startswith('vit'):
            self.backbone = network

        #spatial pooling layer
        self.pool = GeM(p = gem_p)

        #normalize on the unit-hypersphere
        self.norm = F.normalize

    def forward(self, img):
        '''
        Output has shape: batch size x descriptor dimension
        '''
        x = self.norm(self.pool(self.backbone(img))).squeeze(-1).squeeze(-1)

        return x

def extract_embeddings(net,dataloader,ms=[1],msp=1,print_freq=None,verbose = False):

    feature_extractor = create_feature_extractor(net, return_nodes=['backbone.getitem_5'])

    if verbose:

        if len(ms) == 1:
            print("Singlescale extraction")
        else:
            print("Multiscale extraction at scales: " + str(ms))
    
    net.eval()
    
    with torch.no_grad():

        vecs = np.zeros((768, len(dataloader.dataset))) #768 -> vb16; 1024 -> vl16
    
        for i,input in enumerate(dataloader):

            if len(ms) == 1 and ms[0] == 1:
                vecs[:, i] = feature_extractor(net,input[0])['backbone.getitem_5'] #.cuda()
            
            else:
                vecs[:, i] = feature_extractor(input[0])['backbone.getitem_5'] 

            if print_freq is not None:
                print("image: "+str(i))

    return vecs.T

But I get the error: TypeError: avg_pool2d(): argument ‘kernel_size’ (position 2) must be tuple of ints, not tuple. Does anyone know how I can fix it?

The error message sounds a bit strange but I guess the dtype of the kernel size might be a float. Could you check it and transform the values to int?

If I try this: print(x.size(-2).dtype) it prints Proxy(getattr_16). Can I just put in (224,224) as the kernel size since those are the image dimensions?

I see that you have your function extract_embeddings but I cannot see where you use it.

In the __init__ method of your Embedder class you define your backbone as the full network, and then in your forward method you use this full model to process the image.

If this is really so, then at the end you should have a 1x1000 tensor with the predictions for each class.

Can you print the shape of the output from the backbone, e.g. print(self.backbone(img).shape)?

If I do that it prints this: Proxy(getattr_16).
What I do is I initialize the model:

if network_variant == 'vb16IN':
    net = Embedder("vit_b_16", pretrained_flag = True)

Then extract the embeddings like this:

train_descr = extract_embeddings(net,train_loader,ms = scales,msp = 1.0,print_freq=20000)

I tried to follow the feature extraction instructions here: AssertionError: Expected (batch_size, seq_length, hidden_dim) got torch.Size([1, 768, 24, 31]) - #4 by Matias_Vasquez
But I’m not sure I’m doing it right

Shape

You get Proxy(getattr_16) by printing the shape of x?
Can you print more information about x? Like the type or what it actually is?

print(type(x))
print(x)
print(x.shape)

Feature Extractor

What this feature extractor does is run the input through the entire architecture and return a dictionary with the features that you request. In the example from that other response you can see that the input image is given to the feature extractor and NOT to the network.

Maybe an easier way to put this is by using another name. Here I only changed feature_extractor to network_using_features.

# This is the same code as I used in the other response
import torch
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor

network = getattr(torchvision.models,"vit_b_16")(pretrained=True)
network_using_features = create_feature_extractor(network, return_nodes=['getitem_5'])

img = torch.rand(1, 3, 224, 224)
print(network_using_features(img)['getitem_5'])

What do you want to do?

But the main question is: what do you want to do with this architecture?
I think it would really help if you explain what you are trying to do in simple words.

If you only want to fine-tune the Vision Transformer to classify images from your dataset, then we are overcomplicating stuff here.

Also, understanding how the architecture works and what it does to the image will really help you decide what/where/how you want to use the data coming from this architecture.

There are many great videos on YouTube that explain this architecture.

So what I want to do is get the output of getitem_5 and then apply GeM pooling and normalization on it. Is that possible? The bigger picture is: I want to extract the image descriptors using a Vision Transformer and then feed the descriptors into a kNN classifier.
So how can I initialize the backbone to only go up to getitem_5? Because with ResNet I kept only the layers I wanted like this:

self.backbone = nn.Sequential(*list(network.children())[:-2])

But it’s not possible to do it like this with ViT.

I tried to fix it. Can you take a look at this code:

class GeM(nn.Module):
    
    def __init__(self, p=3, eps=1e-6):
        super(GeM,self).__init__()
        self.p = Parameter(torch.ones(1)*p)
        self.eps = eps

    def forward(self, x):
        return gem(x, p=self.p, eps=self.eps)
        
    def __repr__(self):
        return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')'

def gem(x, p=3, eps=1e-6):
    print(x.size(-1))
    return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p)

class Embedder(nn.Module):

    def __init__(self,architecture,gem_p = 3,pretrained_flag = True):

        super(Embedder, self).__init__()

        #load the base model from PyTorch's pretrained models (imagenet pretrained)
        network = getattr(models,architecture)(pretrained=pretrained_flag)

        #keep only the convolutional layers, ends with relu to get non-negative descriptors
        if architecture.startswith('resnet'):
            self.backbone = nn.Sequential(*list(network.children())[:-2])
        elif architecture.startswith('efficientnet'):
            self.backbone = nn.Sequential(*list(network.children())[:-2])
        elif architecture.startswith('vit'):
            self.backbone = create_feature_extractor(network, return_nodes=['getitem_5'])

        #spatial pooling layer
        self.pool = GeM(p = gem_p)

        #normalize on the unit-hypershpere
        self.norm = F.normalize

        #information about the network
        self.meta = {
            'architecture' : architecture, 
            'pooling' : "gem",
            'mean' : [0.485, 0.456, 0.406], #imagenet statistics for imagenet pretrained models
            'std' : [0.229, 0.224, 0.225],
            'outputdim' : OUTPUT_DIM[architecture],
        }

    def forward(self, img):
        '''
        Output has shape: batch size x descriptor dimension
        '''
        x = self.norm(self.pool(self.backbone(img)['getitem_5'])).squeeze(-1).squeeze(-1)

        return x

def extract_ss(net, input):

    return net(input).cpu().data.squeeze()

def extract_ms(net, input, ms, msp):
    
    v = torch.zeros(net.meta['outputdim'])
    
    for s in ms:

        if s == 1:
            input_t = input.clone()
        
        else:    
            input_t = nn.functional.interpolate(input, scale_factor=s, mode='bilinear', align_corners=False)
        
        v += net(input_t).pow(msp).cpu().data.squeeze()
        
    v /= len(ms)
    v = v.pow(1./msp)
    v /= v.norm()

    return v

def extract_embeddings(net,dataloader,ms=[1],msp=1,print_freq=None,verbose = False):

    if verbose:

        if len(ms) == 1:
            print("Singlescale extraction")
        else:
            print("Multiscale extraction at scales: " + str(ms))
    
    net.eval()
    
    with torch.no_grad():

        vecs = np.zeros((net.meta['outputdim'], len(dataloader.dataset)))
    
        for i,input in enumerate(dataloader):

            if len(ms) == 1 and ms[0] == 1:
                vecs[:, i] = extract_ss(net,input[0]) #.cuda()
            
            else:
                vecs[:, i] = extract_ms(net,input[0], ms, msp) #.cuda()

            if print_freq is not None:
                if i%print_freq == 0:
                     print("image: "+str(i))

    return vecs.T

But now I get the error:

IndexError: Dimension out of range (expected to be in range of [-2, 1], but got -3)

from line:

return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p)

New approach

So what you did now looks way better. Now you are really using the feature extractor to get some information from the architecture and use it (and not the whole architecture).

As you may have seen, there are still some problems.

In order to fix them we need to truly understand what is happening with ViT and what we want to do. (That is why I asked what you intended to do with this.)

If your goal is to simply use ViT and classify directly with it, then there is a simpler approach. If, as you mentioned, you want to use the “image features” and feed them to another architecture, then we need to understand ViT.

  • ViT as classifier for N classes
  • Use “image features” and feed them to something else

ViT as classifier for N classes

Here there is not much to understand. We have our architecture and in the end there is a head that classifies the image into N classes. To change this we only need to change the head to meet our needs and fine-tune the model.

For the torchvision implementation of vit_b_16, we can access the head like this

import torchvision

model = torchvision.models.vit_b_16(pretrained=True)
print(model.heads)

# Output:
#Sequential(
#  (head): Linear(in_features=768, out_features=1000, bias=True)
#)

So if we wanted to change this it would be very straightforward.

With this simple code we get our model with the right number of classes and we can fine-tune it to classify our specific dataset.

import torch
import torchvision

N_CLASSES = 2

model = torchvision.models.vit_b_16(pretrained=True)
model.heads.head = torch.nn.Linear(768, N_CLASSES)

Use “image features” and feed them to something else

If we want to do this approach, then we need to first understand ViT.

Understanding ViT

So these are the steps that we will look into:

  • Making patches
  • Flatten the patches
  • Class token
  • Positional Embedding
  • Encoder
    • Block
  • Head

  • Making patches

In the lower left corner of the image you can see how the image is divided into patches. For the architecture we are using (vit_b_16) each patch has 3x16x16 pixels. For an image of 224x224 pixels, we get 14x14 patches (14*16=224). This is why the input image has to have this size.

If we have a batch (B) of one image we could write it like this B x Patch_j x Patch_i x C x H x W = 1 x 14 x 14 x 3 x 16 x 16.

This means we have one image. This image has 14 patches in the y direction (j) and 14 patches in the x direction (i). Each of these patches has 3 channels and is 16 by 16 pixels in the x and y directions.
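If it helps, here is a shape-only sketch of this step (not the exact torchvision implementation, which uses a Conv2d for the patch projection, just an illustration of the shapes):

import torch

# Cut a 224x224 image into 14x14 patches of 3x16x16 pixels (shapes only)
img = torch.rand(1, 3, 224, 224)                       # B x C x H x W
B, C, H, W = img.shape
P = 16                                                 # patch size for vit_b_16

patches = img.reshape(B, C, H // P, P, W // P, P)      # B x C x 14 x 16 x 14 x 16
patches = patches.permute(0, 2, 4, 1, 3, 5)            # B x Patch_j x Patch_i x C x H x W
print(patches.shape)                                   # torch.Size([1, 14, 14, 3, 16, 16])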

  • Flatten the patches

In the image above, we can see there is a pink box with the title Linear Projection of Flattened Patches. What we are doing here is rearranging these patches. So before we had this structure

B x Patch_j x Patch_i x C x H x W = 1 x 14 x 14 x 3 x 16 x 16

now we merge the 14x14 patches into one dimension. We also merge CxHxW into one dimension.

B x Patches x Pixels = 1 x 14*14 x 3*16*16 = 1 x 196 x 768

This is also known as Batch x Seq_Length x Hidden_Dim.

So if we understand this, we can see that the image now has a different representation: a sequence of patch vectors.
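Here is a shape-only sketch of this rearranging step (again just to illustrate the shapes; in the real model the flattened patches then go through the linear projection, and 3*16*16 = 768 just happens to equal the Hidden_Dim of vit_b_16):

import torch

# Merge the 14x14 patch grid into one sequence dimension and flatten each patch
B, C, H, W, P = 1, 3, 224, 224, 16
patches = torch.rand(B, H // P, W // P, C, P, P)           # B x 14 x 14 x C x 16 x 16
seq = patches.reshape(B, (H // P) * (W // P), C * P * P)   # B x Patches x Pixels
print(seq.shape)                                           # torch.Size([1, 196, 768]) -> B x Seq_Length x Hidden_Dim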

  • Class token

After this pink box in the architecture, there are several pink boxes with numbers and an extra box with an asterisk on the left. As described in the image, this asterisk is the Extra learnable [class] embedding. This is where the result of the transformer will end up.

The size of this class token is B x 1 x Hidden_Dim = 1 x 1 x 768 for our case.

This is then prepended at the beginning of our representation of the image on the Seq_Length dim.

[cls]; img -> B x (Seq_Length + 1) x Hidden_Dim = 1 x 197 x 768.

Now Seq_Length = 197. Here we have our 14x14 patches plus one for our [cls] token.
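A small sketch of the prepending (the attribute names here are illustrative, not the exact torchvision ones):

import torch

# Prepend a learnable [class] token to the patch sequence
B, seq_length, hidden_dim = 1, 196, 768
x = torch.rand(B, seq_length, hidden_dim)                  # projected patches
cls_token = torch.nn.Parameter(torch.zeros(1, 1, hidden_dim))

x = torch.cat([cls_token.expand(B, -1, -1), x], dim=1)     # prepend on the Seq_Length dim
print(x.shape)                                             # torch.Size([1, 197, 768])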

  • Positional Embedding

Also on the image we can see that a positional embedding is added. This is so the model can learn how the positions of the patches relate to each other.

The dimensions, however, stay the same.

B x Seq_Length x Hidden_Dim = 1 x 197 x 768
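Sketched in the same way (the initialization is just an example, not necessarily what torchvision uses):

import torch

# Add a learnable positional embedding; the shape does not change
B, seq_length, hidden_dim = 1, 197, 768
x = torch.rand(B, seq_length, hidden_dim)
pos_embedding = torch.nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02))

x = x + pos_embedding          # broadcast over the batch dimension
print(x.shape)                 # torch.Size([1, 197, 768])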

  • Encoder

Now comes the fun part. On the right of the image is the architecture of ONE encoder block. There are L blocks in the ViT architecture (L=12 in ours), meaning the data is fed to the first encoder block, the output of the first goes to the second, and so on sequentially.

All of the blocks have the same architecture, meaning the size of the output has to be the same as the input in order to be fed to the next block.

Size_in = B x Seq_Length x Hidden_Dim = 1 x 197 x 768
Size_out = B x Seq_Length x Hidden_Dim = 1 x 197 x 768
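You can check this directly on the torchvision model (a small sketch, assuming the encoder blocks are stored in model.encoder.layers):

import torch
import torchvision

model = torchvision.models.vit_b_16(pretrained=True)
print(len(model.encoder.layers))        # 12 encoder blocks

x = torch.rand(1, 197, 768)             # B x Seq_Length x Hidden_Dim
for block in model.encoder.layers:
    x = block(x)                        # every block keeps the shape
print(x.shape)                          # torch.Size([1, 197, 768])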

  • Block

If we look at how the block architecture is built, we can see some skip connections, layer normalizations and a multi-layer perceptron (MLP) at the end. There is also a Multi-Head Attention layer (MHA).

The MHA consists of multiple parallel scaled dot product attention mechanisms with learnable parameters.

One scaled dot-product attention is described by the following equation:

Attention(Q, K, V) = softmax(Q Kᵀ / √d_k) V

There is too much to unpack about this equation, but the important thing is that the patches are attending to each other. And HERE might be a good place to get the image features from.

To see how, see below in the Image Features section.
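Just to make the equation concrete, here is a minimal single-head sketch with no learnable projections (the real MHA adds learned Q/K/V projections and runs 12 of these heads in parallel):

import torch
import torch.nn.functional as F

# Minimal single-head scaled dot-product attention (no learned projections)
def scaled_dot_product_attention(q, k, v):
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / d_k ** 0.5   # B x Seq x Seq
    weights = F.softmax(scores, dim=-1)             # how much each token attends to every other token
    return weights @ v, weights

x = torch.rand(1, 197, 768)                         # B x Seq_Length x Hidden_Dim
out, attn = scaled_dot_product_attention(x, x, x)
print(out.shape, attn.shape)                        # torch.Size([1, 197, 768]) torch.Size([1, 197, 197])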

  • Head

Up until now our data has the following format

B x (CLS_Token + Seq_Length) x Hidden_Dim = 1 x 197 x 768.

We said that the CLS Token is where the classification will be done. So now we only take the CLS token and do not care for the rest.

B x 1 x Hidden_Dim = 1 x 1 x 768

As we saw at the beginning of this post, the head consists (in this case) of a Linear layer with 768 in-features and 1000 out-features. This means that we feed our data to this Linear layer and classify between 1000 classes.

On the image with the ViT architecture, this is represented by the yellow box with MLP Head written on it.
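As a small sketch (using torchvision's vit_b_16; note that the indexing x[:, 0] drops the middle dimension, so what actually reaches the head is B x Hidden_Dim = 1 x 768, which is exactly the getitem_5 output):

import torch
import torchvision

model = torchvision.models.vit_b_16(pretrained=True)

x = torch.rand(1, 197, 768)      # B x Seq_Length x Hidden_Dim (encoder output)
cls = x[:, 0]                    # keep only the CLS token -> 1 x 768
logits = model.heads(cls)        # Linear(768 -> 1000)
print(logits.shape)              # torch.Size([1, 1000])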

Image Features

Here is a TensorFlow implementation to get the attention map.

Here is a video + code on how to get them in PyTorch.

Here is also a paper that might interest you.

These methods, however, require that you either switch to TensorFlow or rewrite the ViT to access the attention maps.

With the feature extractor you can get intermediate steps (as we have already done by taking getitem_5, which is almost the last step in the architecture).
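For reference, this is roughly how the node names can be listed with get_graph_node_names:

import torchvision
from torchvision.models.feature_extraction import get_graph_node_names

model = torchvision.models.vit_b_16(pretrained=True)
train_nodes, eval_nodes = get_graph_node_names(model)
print([n for n in eval_nodes if 'encoder_layer_11' in n or 'getitem' in n])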

If we inspect the graph node names, as we did before to get the getitem_5 name, we can see that there is an encoder.layers.encoder_layer_11.self_attention node. But that result is taken after the full Multi-Head Self-Attention has already been applied.

We want an intermediate result.
So for this approach you could do something like this (I ran this in a Python notebook to see all of the heads):

import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import requests
from io import BytesIO
import torchvision
import torchvision.transforms as T
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names


url = "https://www.thesprucepets.com/thmb/k3NXIqobAKvxoQ2ozGcwPxzIkpI=/3300x1856/smart/filters:no_upscale()/most-obedient-dog-breeds-4796922-hero-4440a0ccec0e42c98c5e58821fc9f165.jpg"
response = requests.get(url)
img = Image.open(BytesIO(response.content))

img = T.Resize((224,224))(T.ToTensor()(img))
plt.imshow(img.permute(1, 2, 0))
plt.show()


model = torchvision.models.vit_b_16(pretrained=True)
keys = ['encoder.layers.encoder_layer_11.ln']
feature_extractor = create_feature_extractor(model, return_nodes=keys)
feature_extractor.eval()


out = feature_extractor(img.unsqueeze(0))

x = out[keys[0]]
x_, attn = model.encoder.layers[-1].self_attention(x, x, x, need_weights=True, average_attn_weights=False)

print(x_.shape)
print(attn.shape)

for i in range(12):
    sns.heatmap(attn[0, i, 0, 1:].view(14, 14).detach().numpy())
    plt.show()

The output will be the 12 self-attention heads plotted as heatmaps. As you will see, for this particular example, many will mean nothing to you, but one of the heads looks like this:

[attention heatmap]

If you set average_attn_weights=True you will get the average of all 12 attention heads, which will also not mean much to us.

But if you also get the ‘heads’ key and print the predicted class, you will see that the predicted class is correct (208 = Labrador retriever). So it means that it is working.

keys = ['encoder.layers.encoder_layer_11.ln', 'heads']
print(out['heads'].argmax(dim=1))

You could do something like this and use the self-attention given by the ViT, but you need to understand what is happening, what you want to use, and how.

Also, this will only be a suggestion of where important features might be. These heatmaps alone may mean nothing when fed to another architecture. So you might want to experiment by scaling them to the actual size of the image and feeding both the image AND this heatmap to another architecture.

But these are just suggestions of what you could theoretically do.

The most important thing is understanding what it does and how it does it. Then you can decide how to proceed with this information.

Hope this is a bit clearer now.


Oh, and I forgot to address the error that you mentioned.

As you now know, the getitem_5 stage of the ViT is right before the head, right after only the CLS token has been selected, which we covered in the Head section above.

At that point the shape is B x 768, a 2-D tensor with no spatial dimensions left. If you try to run avg_pool2d on a tensor with this shape, it will give an error (that is where the IndexError about dimension -3 comes from).

The expected input for avg_pool2d has the shape B x C x H x W.
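If you still want to apply your GeM pooling on ViT features, one thing you could try (just a sketch, and it assumes that encoder.ln, the encoder's final LayerNorm, is a valid node name; check get_graph_node_names to be sure) is to take the token sequence before the CLS token is selected, drop the CLS token and reshape the 196 patch tokens back into a 14x14 spatial map that avg_pool2d can handle:

import torch
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor

model = torchvision.models.vit_b_16(pretrained=True)
extractor = create_feature_extractor(model, return_nodes=['encoder.ln'])   # assumed node name
extractor.eval()

img = torch.rand(1, 3, 224, 224)
tokens = extractor(img)['encoder.ln']                           # B x 197 x 768
patch_tokens = tokens[:, 1:]                                    # drop the CLS token -> B x 196 x 768
fmap = patch_tokens.permute(0, 2, 1).reshape(1, 768, 14, 14)    # B x C x H x W
print(fmap.shape)                                               # torch.Size([1, 768, 14, 14]), something GeM can pool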

Thank you so much for the very clear explanation! I’ll try to implement all of this.
