RuntimeError: shape '[128, 3, 9, 16, 9, 16]' is invalid for input of size 9586176

I am following a tutorial and trying to extract image descriptors using a pre-trained Vision Transformer (vit_b_16). However, when I run the code I get this error: RuntimeError: shape '[128, 3, 9, 16, 9, 16]' is invalid for input of size 9586176.

The code looks like this:

net = ViT(model_kwargs={
              'embed_dim': 256,
              'hidden_dim': 512,
              'num_heads': 8,
              'num_layers': 6,
              'patch_size': 16,
              'num_channels': 3,
              'num_patches': 196,
              'num_classes': num_classes,
              'dropout': 0.2
          },
          lr=3e-4)
def img_to_patch(x, patch_size, flatten_channels=True):

    B, C, H, W = x.shape
    print(x.shape)
    x = x.reshape(B, C, H//patch_size, patch_size, W//patch_size, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5)  # [B, H', W', C, p_H, p_W]
    x = x.flatten(1, 2)              # [B, H'*W', C, p_H, p_W]
    if flatten_channels:
        x = x.flatten(2, 4)          # [B, H'*W', C*p_H*p_W]
    print(x.shape)
    return x

class AttentionBlock(nn.Module):

    def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
        """
        Inputs:
            embed_dim - Dimensionality of input and attention feature vectors
            hidden_dim - Dimensionality of hidden layer in feed-forward network
                         (usually 2-4x larger than embed_dim)
            num_heads - Number of heads to use in the Multi-Head Attention block
            dropout - Amount of dropout to apply in the feed-forward network
        """
        super().__init__()

        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        self.linear = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        inp_x = self.layer_norm_1(x)
        x = x + self.attn(inp_x, inp_x, inp_x)[0]
        x = x + self.linear(self.layer_norm_2(x))
        return x


class VisionTransformer(nn.Module):
    
    def __init__(self, embed_dim, hidden_dim, num_channels, num_heads, num_layers, num_classes, patch_size, num_patches, dropout=0.0):
        """
        Inputs:
            embed_dim - Dimensionality of the input feature vectors to the Transformer
            hidden_dim - Dimensionality of the hidden layer in the feed-forward networks
                         within the Transformer
            num_channels - Number of channels of the input (3 for RGB)
            num_heads - Number of heads to use in the Multi-Head Attention block
            num_layers - Number of layers to use in the Transformer
            num_classes - Number of classes to predict
            patch_size - Number of pixels that the patches have per dimension
            num_patches - Maximum number of patches an image can have
            dropout - Amount of dropout to apply in the feed-forward network and
                      on the input encoding
        """
        super().__init__()

        self.patch_size = patch_size

        # Layers/Networks
        self.input_layer = nn.Linear(num_channels*(patch_size**2), embed_dim)
        self.transformer = nn.Sequential(*[AttentionBlock(embed_dim, hidden_dim, num_heads, dropout=dropout) for _ in range(num_layers)])
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )
        self.dropout = nn.Dropout(dropout)

        # Parameters/Embeddings
        self.cls_token = nn.Parameter(torch.randn(1,1,embed_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1,1+num_patches,embed_dim))

    def forward(self, x):
        # Preprocess input
        x = img_to_patch(x, self.patch_size)
        B, T, _ = x.shape
        x = self.input_layer(x)

        # Add CLS token and positional encoding
        cls_token = self.cls_token.repeat(B, 1, 1)
        x = torch.cat([cls_token, x], dim=1)
        x = x + self.pos_embedding[:,:T+1]

        # Apply Transformer
        x = self.dropout(x)
        x = x.transpose(0, 1)
        x = self.transformer(x)

        # Perform classification prediction
        cls = x[0]
        out = self.mlp_head(cls)
        return out

class ViT(pl.LightningModule):

    def __init__(self, model_kwargs, lr):
        super().__init__()
        self.save_hyperparameters()
        self.model = VisionTransformer(**model_kwargs)
       # self.example_input_array = next(iter(train_loader))[0]

    def forward(self, x):
        return self.model(x)
def augmentation(key, imsize=224):
    '''Using ImageNet statistics for normalization.
    '''

    augment_dict = {
        "augment_inference":
            transforms.Compose([
                transforms.RandomResizedCrop(imsize, scale=(0.7, 1.0), ratio=(0.99, 1/0.99)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
    }

    return augment_dict[key]

I can’t tell where I messed up the dimensions. Any help is greatly appreciated.

Edit: I think it has to do with the reshaping in img_to_patch(): it breaks when the height and width are not multiples of patch_size. But how do I fix it?
Printing x.shape at the beginning and end of the forward function gives:

forward beg: torch.Size([128, 3, 224, 224])
forward end: torch.Size([197, 128, 256])
forward beg: torch.Size([128, 3, 158, 158])

And after this it fails.
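For reference, a standalone check with the sizes from the prints above reproduces it (assuming patch_size=16, as in the model config):

import torch

x = torch.randn(128, 3, 224, 224)
print(img_to_patch(x, 16).shape)   # torch.Size([128, 196, 768]) -- works, 224 is a multiple of 16

x = torch.randn(128, 3, 158, 158)
img_to_patch(x, 16)                # RuntimeError: shape '[128, 3, 9, 16, 9, 16]' is invalid
                                   # for input of size 9586176 -- 158 is not a multiple of 16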

Are the augmentations being applied for every batch here? The desired output size looks like 224x224 (which is what the augmentations specify), but in the second forward the input size is 158x158.

So regarding the augmentations I do this:

extraction_transform = augmentation("augment_inference")
train_dataset = MET_database(root=train_root, mini=args.mini, transform=extraction_transform, im_root=args.im_root)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=False, num_workers=4, pin_memory=True)

Does this mean the augmentations are applied to every batch?

Could you check the shape of the input just before it gets passed to the model, e.g. with the quick check below? If it isn't 224x224, could you check what is happening in MET_database? I'm assuming it calls into the transform function.
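For example, something like this on the loader you defined (just an illustration):

imgs, targets = next(iter(train_loader))   # one batch straight from the DataLoader
print(imgs.shape)                          # should be torch.Size([128, 3, 224, 224]) given RandomResizedCrop(224)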

I have this code to extract the descriptors:

train_descr = extract_embeddings(net, train_loader, ms=scales, msp=1.0, print_freq=20000)

def extract_embeddings(net, dataloader, ms=[1], msp=1, print_freq=None, verbose=False):

    if verbose:

        if len(ms) == 1:
            print("Singlescale extraction")
        else:
            print("Multiscale extraction at scales: " + str(ms))
    
    net.eval()
    
    with torch.no_grad():

        vecs = np.zeros((768, len(dataloader.dataset))) #net.meta['outputdim']
    
        for i,input in enumerate(dataloader):

            if len(ms) == 1 and ms[0] == 1:
                vecs[:, i] = extract_ss(net,input[0]) #.cuda()
            
            else:
                vecs[:, i] = extract_ms(net,input[0], ms, msp) #.cuda()

            if print_freq is not None:
                #if i%print_freq == 0:
                print("image: "+str(i))

    return vecs.T

def extract_ss(net, input):

    return net(input).cpu().data.squeeze()

def extract_ms(net, input, ms, msp):
    
    v = torch.zeros(128,224408) #num_classes=224408
    print(input.shape)
    
    for s in ms:

        if s == 1:
            input_t = input.clone()
        
        else:
            input_t = nn.functional.interpolate(input, scale_factor=s, mode='bilinear', align_corners=False)

        v += net(input_t).pow(msp).cpu().data.squeeze()
        
    v /= len(ms)
    v = v.pow(1./msp)
    v /= v.norm()

    return v

If I print input.shape in extract_ms I get torch.Size([128, 3, 224, 224])

Edit: If I print input_t.shape within extract_ms() I get torch.Size([128, 3, 158, 158])

And the MET_database class looks like this:

class MET_database(VisionDataset):

    def __init__(
            self,
            root: str = ".",
            mini: bool = False,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            loader: Callable[[str], Any] = default_loader,
            is_valid_file: Optional[Callable[[str], bool]] = None,
            im_root = None
    ) -> None:
        super().__init__(root, transform=transform,
                            target_transform=target_transform)

        fn = "MET_database.json"

        if mini:
            fn = "mini_"+fn

        with open(os.path.join(self.root, fn)) as f:
            data = json.load(f)

        samples = []
        targets = []

        for e in data:
            samples.append(e['path'])
            targets.append(int(e['id']))

        self.loader = loader
        self.samples = samples
        self.targets = targets

        assert len(self.samples) == len(self.targets)

        self.im_root = im_root


    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        
        if self.im_root is not None:
            path = os.path.join(self.im_root, "images/" + self.samples[index])            

        else:
            path = os.path.join(os.path.dirname(self.root), "images/" + self.samples[index])

        target = self.targets[index]
        sample = self.loader(path)
        
        if self.transform is not None:
            sample = self.transform(sample)
        
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target


    def __len__(self) -> int:
        
        return len(self.samples)

Yes, the interpolation breaks the reshapes in img_to_patch once the scaled height and width are no longer multiples of the patch size: 158 // 16 = 9, so the reshape to [128, 3, 9, 16, 9, 16] expects 128*3*144*144 = 7,962,624 elements, while a [128, 3, 158, 158] tensor has 9,586,176, which is exactly the error reported above. Could you round the scaled resolution to the nearest multiple of 16 (e.g., 160x160) and pass that size directly? torch.nn.functional.interpolate accepts either a size or a scale_factor (see torch.nn.functional.interpolate — PyTorch 1.11.0 documentation). A sketch of that change is below.
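Something along these lines (a rough sketch, untested against your pipeline; rescale_to_patch_multiple is just an illustrative helper name):

import torch.nn.functional as F

def rescale_to_patch_multiple(x, s, patch_size=16):
    # resize by scale s, snapping H and W to the nearest (nonzero) multiple of patch_size
    _, _, h, w = x.shape
    new_h = max(patch_size, int(round(h * s / patch_size)) * patch_size)
    new_w = max(patch_size, int(round(w * s / patch_size)) * patch_size)
    return F.interpolate(x, size=(new_h, new_w), mode='bilinear', align_corners=False)

# then, inside extract_ms, instead of scale_factor=s:
# input_t = rescale_to_patch_multiple(input, s, patch_size=16)

A scale that would have produced 158x158 now produces 160x160, so the reshape in img_to_patch stays valid.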

Thank you, that got rid of the error! But now I'm having trouble with vecs from extract_embeddings() and v from extract_ms(): I'm not sure what shape to initialize them with. I keep changing them, but I get errors like ValueError: could not broadcast input array from shape (128,224408) into shape (768,) and RuntimeError: The size of tensor a (768) must match the size of tensor b (224408) at non-singleton dimension 1.

You can simply print the model's output shape to see what v should be initialized to, and check that it matches your expectations, i.e. something like [batch_size, embedding_dim]; a minimal check is sketched below.
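For instance (a sketch, reusing the train_loader from above; out_dim is just a local name):

with torch.no_grad():
    batch = next(iter(train_loader))[0]
    out = net(batch)
print(out.shape)   # for the ViT above this should be [batch_size, num_classes]

out_dim = out.shape[-1]
# size the accumulators from that, e.g.
# vecs = np.zeros((out_dim, len(train_loader.dataset)))   # one column per image
# v    = torch.zeros(out_dim)                              # per-image accumulator in extract_ms

Note also that extract_embeddings fills vecs[:, i] once per dataloader iteration, i.e. once per batch, so with batch_size=128 the column index won't correspond to a single image unless you write a slice of columns per batch (or use batch_size=1).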