Hi. I used the following code to extract descriptors from images using CNNs (ResNet, EfficientNet) and tried to do the same thing with a ViT, but I’m getting the following error: AssertionError: Expected (batch_size, seq_length, hidden_dim) got torch.Size([1, 768, 24, 31]). If someone could help me modify the code to make it work, that would be great.
extract_descriptors.py
# Build the three dataset splits (train database, test queries, val queries).
train_dataset = MET_database(root=train_root, mini=args.mini, transform=extraction_transform, im_root=args.im_root)
test_dataset = MET_queries(root=query_root, test=True, transform=extraction_transform, im_root=args.im_root)
val_dataset = MET_queries(root=query_root, transform=extraction_transform, im_root=args.im_root)


def _make_loader(dataset):
    # All splits use identical loader settings: one image per batch,
    # original order preserved so descriptor rows align with dataset indices.
    return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=8, pin_memory=True)


train_loader = _make_loader(train_dataset)
test_loader = _make_loader(test_dataset)
val_loader = _make_loader(val_dataset)

# initialization of the global descriptor extractor model
_ARCHITECTURES = {
    'r18I': "resnet18",
    'eb4': "efficientnet_b4",
    'vb16': "vit_b_16",
}
if network_variant not in _ARCHITECTURES:
    raise ValueError('Unsupported architecture: {}!'.format(network_variant))
net = Embedder(_ARCHITECTURES[network_variant], gem_p=3.0, pretrained_flag=True)

# multi-scale case uses three scales; single-scale case uses the identity scale only
scales = [1, 1/np.sqrt(2), 1/2] if args.ms else [1]

print("Starting the extraction of the descriptors")
train_descr = extract_embeddings(net, train_loader, ms=scales, msp=1.0, print_freq=20000)
test_descr = extract_embeddings(net, test_loader, ms=scales, msp=1.0, print_freq=5000)
val_descr = extract_embeddings(net, val_loader, ms=scales, msp=1.0, print_freq=1000)
backbone.py
class Embedder(nn.Module):
    """Global image descriptor extractor: backbone features -> GeM pooling -> L2 norm.

    Supports torchvision ``resnet*``, ``efficientnet_*`` and ``vit_*``
    architectures. For the CNNs, the classifier head is stripped so the
    backbone returns a spatial feature map of shape (B, C, H, W).

    For ViT, simply wrapping ``network.children()`` in ``nn.Sequential``
    is wrong: it skips ``VisionTransformer._process_input`` (patchify,
    flatten to tokens, prepend class token, add positional embedding),
    so the encoder receives the raw 4-D conv map and fails with
    ``AssertionError: Expected (batch_size, seq_length, hidden_dim) got
    torch.Size([1, 768, 24, 31])``. The fix is to keep the whole model
    and re-implement the tokenization in ``_vit_features`` below, then
    reshape the patch tokens back into a (B, C, h, w) map so the same
    GeM pooling path applies.
    """

    def __init__(self, architecture, gem_p=3, pretrained_flag=True, projector=False, init_projector=None):
        super(Embedder, self).__init__()
        # load the base model from PyTorch's pretrained models (imagenet pretrained)
        network = getattr(models, architecture)(pretrained=pretrained_flag)
        self.is_vit = architecture.startswith('vit')
        if self.is_vit:
            # keep the full VisionTransformer: _vit_features needs conv_proj,
            # class_token and the encoder's sub-modules individually
            self.backbone = network
        elif architecture.startswith('resnet') or architecture.startswith('efficientnet'):
            # keep only the convolutional layers, ends with relu to get non-negative descriptors
            self.backbone = nn.Sequential(*list(network.children())[:-2])
        else:
            raise ValueError('Unsupported architecture: {}!'.format(architecture))
        # spatial pooling layer
        self.pool = GeM(p=gem_p)
        # normalize on the unit-hypersphere
        self.norm = F.normalize
        # information about the network
        self.meta = {
            'architecture': architecture,
            'pooling': "gem",
            'mean': [0.485, 0.456, 0.406],  # imagenet statistics for imagenet pretrained models
            'std': [0.229, 0.224, 0.225],
            'outputdim': OUTPUT_DIM[architecture],
        }
        self.projector = None

    def _vit_features(self, img):
        """Return a (B, D, h, w) patch-token feature map from a torchvision ViT.

        Re-implements VisionTransformer._process_input + Encoder.forward,
        with one extension: when the input grid (h, w) differs from the
        pretrained positional-embedding grid (e.g. 24x31 patches instead of
        14x14 for 224x224 training), the grid part of the positional
        embedding is bicubically interpolated to (h, w); the class-token
        position is kept as-is.
        """
        import torch  # local import: file-level import block not visible here

        vit = self.backbone
        x = vit.conv_proj(img)                       # (B, D, h, w) patch embedding
        B, D, h, w = x.shape
        x = x.flatten(2).transpose(1, 2)             # (B, h*w, D) token sequence
        x = torch.cat([vit.class_token.expand(B, -1, -1), x], dim=1)

        pos = vit.encoder.pos_embedding              # (1, 1 + gh*gw, D)
        if pos.shape[1] != x.shape[1]:
            cls_pos, grid_pos = pos[:, :1], pos[:, 1:]
            g = int(round(grid_pos.shape[1] ** 0.5))  # pretrained grid is square
            grid_pos = grid_pos.reshape(1, g, g, D).permute(0, 3, 1, 2)
            grid_pos = F.interpolate(grid_pos, size=(h, w), mode='bicubic', align_corners=False)
            grid_pos = grid_pos.permute(0, 2, 3, 1).reshape(1, h * w, D)
            pos = torch.cat([cls_pos, grid_pos], dim=1)

        # replicate Encoder.forward, but with the (possibly resized) embedding
        x = vit.encoder.dropout(x + pos)
        x = vit.encoder.ln(vit.encoder.layers(x))

        # drop the class token and restore the spatial layout for GeM pooling.
        # NOTE(review): ViT tokens are not non-negative like post-ReLU CNN maps;
        # assumes the GeM implementation clamps to a small eps before pow — confirm.
        return x[:, 1:].transpose(1, 2).reshape(B, D, h, w)

    def forward(self, img):
        '''
        Output has shape: batch size x descriptor dimension
        '''
        feats = self._vit_features(img) if self.is_vit else self.backbone(img)
        x = self.norm(self.pool(feats)).squeeze(-1).squeeze(-1)
        return x
def extract_ss(net, input):
    # Single-scale descriptor: one forward pass, moved to CPU with
    # batch/singleton dimensions squeezed away.
    descriptor = net(input)
    return descriptor.cpu().data.squeeze()
def extract_ms(net, input, ms, msp):
    """Multi-scale descriptor: average of per-scale descriptors raised to the
    power msp, then taken back to the 1/msp power and L2-normalized
    (generalized-mean aggregation across scales)."""

    def _at_scale(scale):
        # Scale 1 keeps the original resolution; others resample bilinearly.
        if scale == 1:
            return input.clone()
        return nn.functional.interpolate(input, scale_factor=scale, mode='bilinear', align_corners=False)

    acc = torch.zeros(net.meta['outputdim'])
    for scale in ms:
        acc = acc + net(_at_scale(scale)).pow(msp).cpu().data.squeeze()
    descriptor = (acc / len(ms)).pow(1.0 / msp)
    return descriptor / descriptor.norm()
utils.py
def extract_embeddings(net, dataloader, ms=[1], msp=1, print_freq=None, verbose=False):
    """Extract one global descriptor per image in the loader.

    Args:
        net: Embedder-like module with net.meta['outputdim'].
        dataloader: yields (image, ...) tuples with batch_size 1; iteration
            order defines descriptor row order.
        ms: list of scales; [1] selects the single-scale path (extract_ss),
            anything else the multi-scale path (extract_ms).
        msp: exponent for multi-scale generalized-mean aggregation.
        print_freq: if not None, log progress every print_freq images.
        verbose: print which extraction mode is used.

    Returns:
        np.ndarray of shape (num_images, outputdim).
    """
    if verbose:
        if len(ms) == 1:
            print("Singlescale extraction")
        else:
            print("Multiscale extraction at scales: " + str(ms))
    net.eval()
    with torch.no_grad():
        vecs = np.zeros((net.meta['outputdim'], len(dataloader.dataset)))
        for i, input in enumerate(dataloader):
            if len(ms) == 1 and ms[0] == 1:
                vecs[:, i] = extract_ss(net, input[0])  # .cuda()
            else:
                vecs[:, i] = extract_ms(net, input[0], ms, msp)  # .cuda()
            # Bug fix: the modulo guard was commented out, so any non-None
            # print_freq logged on every single image.
            if print_freq is not None and i % print_freq == 0:
                print("image: " + str(i))
    return vecs.T