AssertionError: Expected (batch_size, seq_length, hidden_dim) got torch.Size([1, 768, 24, 31])

Hi. I used the following code to extract descriptors from images using CNNs (resnet, efficientnet) and tried to do the same thing with ViT but I’m getting the following error: AssertionError: Expected (batch_size, seq_length, hidden_dim) got torch.Size([1, 768, 24, 31]). If someone could help me modify the code in order to make it work that would be great.

extract_descriptors.py

    train_dataset = MET_database(root = train_root,mini = args.mini,transform = extraction_transform,im_root = args.im_root)
	test_dataset = MET_queries(root = query_root,test = True,transform = extraction_transform,im_root = args.im_root)
	val_dataset = MET_queries(root = query_root,transform = extraction_transform,im_root = args.im_root)
    train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=1,shuffle=False,num_workers=8,pin_memory=True)
	test_loader = torch.utils.data.DataLoader(test_dataset,batch_size=1,shuffle=False,num_workers=8,pin_memory=True)
	val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=1,shuffle=False,num_workers=8,pin_memory=True)

	#initialization of the global descriptor extractor model

	if network_variant == 'r18I':
		net = Embedder("resnet18", gem_p = 3.0, pretrained_flag = True)
	elif network_variant == 'eb4':
		net = Embedder("efficientnet_b4", gem_p = 3.0, pretrained_flag = True)
	elif network_variant == 'vb16':
		net = Embedder("vit_b_16", gem_p = 3.0, pretrained_flag = True)
	else:
		raise ValueError('Unsupported  architecture: {}!'.format(network_variant))

	if args.ms:
		#multi-scale case
		scales = [1, 1/np.sqrt(2), 1/2]

	else:
		#single-scale case
		scales = [1]

	print("Starting the extraction of the descriptors")

	train_descr = extract_embeddings(net,train_loader,ms = scales,msp = 1.0,print_freq=20000)
	test_descr = extract_embeddings(net,test_loader,ms = scales,msp = 1.0,print_freq=5000)
	val_descr = extract_embeddings(net,val_loader,ms = scales,msp = 1.0,print_freq=1000)

backbone.py

class Embedder(nn.Module):

	def __init__(self,architecture,gem_p = 3,pretrained_flag = True,projector = False,init_projector = None):

		super(Embedder, self).__init__()

		#load the base model from PyTorch's pretrained models (imagenet pretrained)
		network = getattr(models,architecture)(pretrained=pretrained_flag)

		#keep only the convolutional layers, ends with relu to get non-negative descriptors
		if architecture.startswith('resnet'):
			self.backbone = nn.Sequential(*list(network.children())[:-2])
		elif architecture.startswith('efficientnet'):
			self.backbone = nn.Sequential(*list(network.children())[:-2])
		elif architecture.startswith('vit'):
			self.backbone = nn.Sequential(*list(network.children())[:-1])

		#spatial pooling layer
		self.pool = GeM(p = gem_p)

		#normalize on the unit-hypershpere
		self.norm = F.normalize

		#information about the network
		self.meta = {
			'architecture' : architecture, 
			'pooling' : "gem",
			'mean' : [0.485, 0.456, 0.406], #imagenet statistics for imagenet pretrained models
			'std' : [0.229, 0.224, 0.225],
			'outputdim' : OUTPUT_DIM[architecture],
		}

		self.projector = None


	def forward(self, img):
		'''
		Output has shape: batch size x descriptor dimension
		'''

		x = self.norm(self.pool(self.backbone(img))).squeeze(-1).squeeze(-1)

		return x

def extract_ss(net, input):

	return net(input).cpu().data.squeeze()

def extract_ms(net, input, ms, msp):
	
	v = torch.zeros(net.meta['outputdim'])
	
	for s in ms:

		if s == 1:
			input_t = input.clone()
		
		else:    
			input_t = nn.functional.interpolate(input, scale_factor=s, mode='bilinear', align_corners=False)
		
		v += net(input_t).pow(msp).cpu().data.squeeze()
		
	v /= len(ms)
	v = v.pow(1./msp)
	v /= v.norm()

	return v

utils.py

def extract_embeddings(net,dataloader,ms=[1],msp=1,print_freq=None,verbose = False):

    if verbose:

        if len(ms) == 1:
            print("Singlescale extraction")
        else:
            print("Multiscale extraction at scales: " + str(ms))
    
    net.eval()
    
    with torch.no_grad():

        vecs = np.zeros((net.meta['outputdim'], len(dataloader.dataset)))
    
        for i,input in enumerate(dataloader):

            if len(ms) == 1 and ms[0] == 1:
                vecs[:, i] = extract_ss(net,input[0]) #.cuda()
            
            else:
                vecs[:, i] = extract_ms(net,input[0], ms, msp) #.cuda()

            if print_freq is not None:
                #if i%print_freq == 0:
                print("image: "+str(i))

    return vecs.T

If you use the pretrained vit_b_16, then it expects that you enter the correct size of the image (3 x 224 x 224), but it looks like you are giving a (3 x 384 x 496) image.

Then by using nn.Sequential to get the structure of vit_b_16, it does not do the correct preprocessing of the image.

The image must be flattened into patches and added the cls (class) token before entering the encoder.

Here is the source code for vit_b_16.

I followed a tutorial and I modified the backbone.py to look like this:

def img_to_patch(x, patch_size, flatten_channels=True):

        B, C, H, W = x.shape
        x = x.reshape(B, C, H//patch_size, patch_size, W//patch_size, patch_size)
        x = x.permute(0, 2, 4, 1, 3, 5) # [B, H', W', C, p_H, p_W]
        x = x.flatten(1,2)              # [B, H'*W', C, p_H, p_W]
        if flatten_channels:
            x = x.flatten(2,4)          # [B, H'*W', C*p_H*p_W]
        return x

class AttentionBlock(nn.Module):

    def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
        """
        Inputs:
            embed_dim - Dimensionality of input and attention feature vectors
            hidden_dim - Dimensionality of hidden layer in feed-forward network
                         (usually 2-4x larger than embed_dim)
            num_heads - Number of heads to use in the Multi-Head Attention block
            dropout - Amount of dropout to apply in the feed-forward network
        """
        super().__init__()

        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        self.linear = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        inp_x = self.layer_norm_1(x)
        x = x + self.attn(inp_x, inp_x, inp_x)[0]
        x = x + self.linear(self.layer_norm_2(x))
        return x


class VisionTransformer(nn.Module):
    
    def __init__(self, embed_dim, hidden_dim, num_channels, num_heads, num_layers, num_classes, patch_size, num_patches, dropout=0.0):
        """
        Inputs:
            embed_dim - Dimensionality of the input feature vectors to the Transformer
            hidden_dim - Dimensionality of the hidden layer in the feed-forward networks
                         within the Transformer
            num_channels - Number of channels of the input (3 for RGB)
            num_heads - Number of heads to use in the Multi-Head Attention block
            num_layers - Number of layers to use in the Transformer
            num_classes - Number of classes to predict
            patch_size - Number of pixels that the patches have per dimension
            num_patches - Maximum number of patches an image can have
            dropout - Amount of dropout to apply in the feed-forward network and
                      on the input encoding
        """
        super().__init__()

        self.patch_size = patch_size

        # Layers/Networks
        self.input_layer = nn.Linear(num_channels*(patch_size**2), embed_dim)
        self.transformer = nn.Sequential(*[AttentionBlock(embed_dim, hidden_dim, num_heads, dropout=dropout) for _ in range(num_layers)])
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )
        self.dropout = nn.Dropout(dropout)

        # Parameters/Embeddings
        self.cls_token = nn.Parameter(torch.randn(1,1,embed_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1,1+num_patches,embed_dim))

    def forward(self, x):
        # Preprocess input
        print(x.shape)
        x = img_to_patch(x, self.patch_size)
        print(x.shape)
        B, T, _ = x.shape
        x = self.input_layer(x)

        # Add CLS token and positional encoding
        cls_token = self.cls_token.repeat(B, 1, 1)
        x = torch.cat([cls_token, x], dim=1)
        x = x + self.pos_embedding[:,:T+1]

        # Apply Transforrmer
        x = self.dropout(x)
        x = x.transpose(0, 1)
        x = self.transformer(x)

        # Perform classification prediction
        cls = x[0]
        out = self.mlp_head(cls)
        return out

class ViT(pl.LightningModule):

    def __init__(self, model_kwargs, lr):
        super().__init__()
        self.save_hyperparameters()
        self.model = VisionTransformer(**model_kwargs)
       # self.example_input_array = next(iter(train_loader))[0]

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr)
        lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100,150], gamma=0.1)
        return [optimizer], [lr_scheduler]

    def _calculate_loss(self, batch, mode="train"):
        imgs, labels = batch
        preds = self.model(imgs)
        loss = F.cross_entropy(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()

        self.log(f'{mode}_loss', loss)
        self.log(f'{mode}_acc', acc)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._calculate_loss(batch, mode="train")
        return loss

    def validation_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="val")

    def test_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="test")

and in extract_descriptors.py I now have this:

	if network_variant == 'vb16':
		net = ViT(model_kwargs={
                                'embed_dim': 256,
                                'hidden_dim': 512,
                                'num_heads': 8,
                                'num_layers': 6,
                                'patch_size': 4,
                                'num_channels': 3,
                                'num_patches': 64,
                                'num_classes': num_classes,
                                'dropout': 0.2
                            },
							lr=3e-4)

But now I’m getting the error: RuntimeError: shape ‘[128, 3, 5, 4, 5, 4]’ is invalid for input of size 185856.

This is my first time working with PyTorch and Vision Transformers so I’m sorry if I require more ‘dumbed-down’ instructions :sweat_smile:

Hi, sorry for the late reply.

I will recommend you another approach instead of fixing this, later if you want we can see how to fix this.

The problem

So, the problem is that when you use nn.Sequential you get the nn.Modules used inside of the architecture, but all functional API is gone. As you saw on the source code that I linked you, there are some stuff going on inside the forward method of ViT that is not done (flattening + cls token).

Possible solutions

As you have already tried, there are some ways of modifying the model to fit your needs, however it is sometimes problematic. Here are some possible solutions:

  • Rewrite the code (more or less what you are doing)
  • Add forward hooks (a little complicated but it works)
  • and
  • Use torch.fx feature extractor (I think this is my favorite)

Feature Extraction in TorchVision using Torch FX

In order to understan well how, why and what it’s good for, I really recommend this post. I will only explain the bare minimum for it to work.

The network

Here we do the same thing that you did before to load the network

network = getattr(torchvision.models,"vit_b_16")(pretrained=True)

The Feature Extractor

This is where the magic happens

from torchvision.models.feature_extraction import create_feature_extractor
feature_extractor = create_feature_extractor(network, return_nodes=['getitem_5'])

But now you are asking: "Matias, why did you use ‘getitem_5?’ and that is a great question.

Well you need to specify the return_nodes. This means you can specify a list with many places where you want to interrupt your model and get the output for a given input. For this you need to know the exact node_name.

Graph Node Names

In order to know how the node_names are defined, you can use the following code

print(torchvision.models.feature_extraction.get_graph_node_names(network))

This will give you a long list of every node name like this one (these are just the last nodes)
image

So, as you can see, the last node before head is called getitem_5. Here is where we want to get the information.

Using the feature extractor

putting all of it together it would look something like this

from torchvision.models.feature_extraction import create_feature_extractor

network = getattr(torchvision.models,"vit_b_16")(pretrained=True)
feature_extractor = create_feature_extractor(network, return_nodes=['getitem_5'])

img = torch.rand(1, 3, 224, 224)
print(feature_extractor(img)['getitem_5'])

If you try this, you should not get any error. The output shape should be torch.Size([1, 768]). You can now feed this to a MLP or whatever you want to do with it.

Hope this helps :smile: and as I said, if you want to look at your approach anyways, then I can look at what went wrong, but I think this approach is easier and less prone to error.

Thank you so much for this! But how do I feed the feature_extractor my whole training data? I tried to implement it inside extract_embeddings() like this:

def extract_embeddings(net,dataloader,ms=[1],msp=1,print_freq=None,verbose = False):

    feature_extractor = create_feature_extractor(net, return_nodes=['getitem_5'])

    if verbose:

        if len(ms) == 1:
            print("Singlescale extraction")
        else:
            print("Multiscale extraction at scales: " + str(ms))
    
    net.eval()
    
    with torch.no_grad():

        vecs = np.zeros((768, len(dataloader.dataset)))
    
        for i,input in enumerate(dataloader):

            if len(ms) == 1 and ms[0] == 1:
                vecs[:, i] = feature_extractor(net,input[0]) #.cuda()
            
            else:
                #vecs[:, i] = extract_ms(net,input[0], ms, msp) #.cuda()
                vecs[:, i] = feature_extractor(input[0])

            if print_freq is not None:
                print("image: "+str(i))

    return vecs.T

But I get the error from line vecs[:, i] = feature_extractor(input[0]):
TypeError: float() argument must be a string or a number, not ‘dict’

You can feed it the whole batch like you would normally feed a batch to a model.

You just need to extract the output you need from what you get from the feature_extractor

batch = 10
channels = 3
h = 224
w = 224
imgs = torch.rand(batch, channels, h, w)

# This will return a dictionary
feature_extractor(imgs)

# This will get what you actually want from that dictionary
feature_extractor(imgs)['getitem_5']

# The output shape is
# torch.Size([10, 768])
# Meaning you have 10 images, all with 768 hidden_dim

So the error you get is because you are missing [getitem_5] after passing the images through the feature_extractor.

1 Like

It’s working now, thank you so much! So just to check I’m doing everything right:
I initialize the model like this:

net = Embedder("vit_b_16", pretrained_flag = True)

In the Embedder class I have this:

network = getattr(models,architecture)(pretrained=pretrained_flag)

#keep only the convolutional layers, ends with relu to get non-negative descriptors
if architecture.startswith('resnet'):
    self.backbone = nn.Sequential(*list(network.children())[:-2])
elif architecture.startswith('efficientnet'):
    self.backbone = nn.Sequential(*list(network.children())[:-2])
elif architecture.startswith('vit'):
      self.backbone = network

The forward looks like this:

    def forward(self, img):
        '''
        Output has shape: batch size x descriptor dimension
        '''
        x = self.backbone(img).squeeze(-1).squeeze(-1)

        return x

And the extract_embeddings() looks like this:

def extract_embeddings(net,dataloader,ms=[1],msp=1,print_freq=None,verbose = False):

    print(get_graph_node_names(net))
    feature_extractor = create_feature_extractor(net, return_nodes=['backbone.getitem_5'])

    if verbose:

        if len(ms) == 1:
            print("Singlescale extraction")
        else:
            print("Multiscale extraction at scales: " + str(ms))
    
    net.eval()
    
    with torch.no_grad():

        vecs = np.zeros((768, len(dataloader.dataset)))
    
        for i,input in enumerate(dataloader):

            if len(ms) == 1 and ms[0] == 1:
                vecs[:, i] = feature_extractor(net,input[0]) #.cuda()
            
            else:
                #vecs[:, i] = extract_ms(net,input[0], ms, msp) #.cuda()
                vecs[:, i] = feature_extractor(input[0])['backbone.getitem_5']

            if print_freq is not None:
                #if i%print_freq == 0:
                print("image: "+str(i))

    return vecs.T

Does this look right?

1 Like

Yeah, it looks good.

I hope it works correctly :smile: