I am trying to calculate the recall of a CLIP-based model that detects anomalies in the MVTec dataset, but this error always shows up:
```
RuntimeError Traceback (most recent call last)
in <cell line: 4>()
      2 prompts = [item for sublist in prompts for item in sublist]
      3 analyzer = CLIP_ZSAD_Analyzer(model_id='openai/clip-vit-base-patch32', text_prompts=prompts)
----> 4 pcsn, rcl, f1_max = analyzer.analyse()

4 frames

/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)
    117
    118     return decorate_context

in analyse(self)
     82 self.f1s.append(metric.compute().item())
     83 precision = multiclass_precision(pred_labels, labels, average="macro", num_classes=31).item()
---> 84 recal = multiclass_recall(pred_labels, labels, average="macro", num_classes=31).item()
     85 self.recalls.append(recal)
     86 self.precisions.append(precision)

/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)
    117
    118     return decorate_context

/usr/local/lib/python3.10/dist-packages/torcheval/metrics/functional/classification/recall.py in multiclass_recall(input, target, num_classes, average)
    151         input, target, num_classes, average
    152     )
--> 153     return _recall_compute(num_tp, num_labels, num_predictions, average)
    154
    155

/usr/local/lib/python3.10/dist-packages/torcheval/metrics/functional/classification/recall.py in _recall_compute(num_tp, num_labels, num_predictions, average)
    193         num_tp = num_tp[mask]
    194
    195     recall = num_tp / num_labels
    196
    197     isnan_class = torch.isnan(recall)

RuntimeError: The size of tensor a (19) must match the size of tensor b (31) at non-singleton dimension 0
```
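In case it matters, this is the quick sanity check I run on the tensors that go into the metric (just a debugging sketch, not part of the pipeline; `check_metric_inputs` is a helper name I made up, and `pred_labels`/`labels` are meant to be the 1-D tensors from one batch inside `analyse()` shown below). As far as I understand, `multiclass_recall` expects both tensors to hold class indices in `[0, num_classes)`:

```python
import torch
from torcheval.metrics.functional import multiclass_recall

def check_metric_inputs(pred_labels: torch.Tensor, labels: torch.Tensor, num_classes: int):
    """Print the ranges the metric sees and fail early on out-of-range values."""
    for name, t in (("pred_labels", pred_labels), ("labels", labels)):
        print(f"{name}: shape={tuple(t.shape)}, min={t.min().item()}, "
              f"max={t.max().item()}, unique={t.unique().numel()}")
    assert pred_labels.shape == labels.shape, "prediction/label count mismatch"
    assert labels.min() >= 0 and labels.max() < num_classes, "label outside [0, num_classes)"
    assert pred_labels.min() >= 0 and pred_labels.max() < num_classes, "prediction outside [0, num_classes)"
    return multiclass_recall(pred_labels, labels, average="macro", num_classes=num_classes)
```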
This is my code so far:

```python
import torch
from torcheval.metrics import MulticlassF1Score as F1
from torcheval.metrics.functional import multiclass_precision, multiclass_recall
from transformers import CLIPProcessor, CLIPModel
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Dataset, DataLoader
import os
import torchvision.transforms as transforms
from PIL import Image

categories = [d for d in os.listdir('/content/dataset') if os.path.isdir(os.path.join('/content/dataset', d))]


class MVTec_Dataset(Dataset):
    def __init__(self, root_dir):
        self.root = os.path.expanduser(root_dir)
        cwd = os.getcwd()
        self.images = []
        self.ground_truths = {}
        self.labels = []
        self.tranform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
        categories = [os.path.join(self.root, d) for d in os.listdir(self.root) if os.path.isdir(os.path.join(self.root, d))]
        # Label scheme: for category `index`, 'good' test images get index*2 - 1
        # and defective ones get index*2.
        for index, dirs in enumerate(categories):
            for subdir1 in os.listdir(os.path.join(dirs, 'test')):
                label = index*2 - 1 if subdir1 == 'good' else index*2
                for file in os.listdir(os.path.join(dirs, 'test', subdir1)):
                    self.images.append({'path': os.path.join(dirs, 'test', subdir1, file),
                                        'ground_truth': os.path.join(dirs, 'ground_truth', subdir1, file.strip('.png') + '_mask.png') if subdir1 != 'good' else None,
                                        'category': index
                                        })
                    self.labels.append(label)

    def __getitem__(self, idx):
        obj = self.images[idx]
        img = Image.open(obj['path']).convert('RGB')
        img = self.tranform(img)
        bnd_box = self.get_bnd_box(obj['ground_truth']) if obj['ground_truth'] is not None else (0, 0, img.shape[1], img.shape[2])
        bnd_box = torch.tensor(bnd_box)
        label = self.labels[idx]
        id = obj['category']
        return img, label, bnd_box, id

    def get_bnd_box(self, path):
        # Bounding box (min_y, min_x, max_y, max_x) of the nonzero region of the ground-truth mask.
        image = Image.open(path)
        image = self.tranform(image)
        indicies = torch.nonzero(image)
        min_y = indicies[:, 1].min().item()
        min_x = indicies[:, 2].min().item()
        max_y = indicies[:, 1].max().item() + 1
        max_x = indicies[:, 2].max().item() + 1
        bnd = [min_y, min_x, max_y, max_x]
        return bnd

    def __len__(self):
        return len(self.images)

class CLIP_ZSAD_Analyzer:
    def __init__(self, model_id, text_prompts, dataset_dir='/content/dataset'):
        self.test_data = MVTec_Dataset(dataset_dir)
        self.data_loader = DataLoader(self.test_data, batch_size=32, shuffle=True)
        self.text_prompts = text_prompts
        self.processor = CLIPProcessor.from_pretrained(model_id)
        self.model = CLIPModel.from_pretrained(model_id)
        self.model.eval()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model.to(self.device)
        self.precisions = []
        self.recalls = []
        self.f1s = []
        self.patch_size = 90
        self.window_size = 5
        self.num_classes = 30
        self.stride = 1
        self.compare_threshold = .35

    @torch.no_grad()
    def analyse(self):
        for imgs, labels, ibnd_box, index in self.data_loader:
            pred_labels = []
            for image, ibnd, ind in zip(imgs, ibnd_box, index):
                # Cut the image into patch_size x patch_size patches, slide a
                # window_size x window_size window of patches over the grid, and
                # score every window with CLIP against the text prompts.
                patches = image.unfold(0, 3, 3)
                patches = patches.unfold(1, self.patch_size, self.patch_size)
                patches = patches.unfold(2, self.patch_size, self.patch_size)
                scores = torch.zeros(patches.shape[1], patches.shape[2])
                runs = torch.ones(patches.shape[1], patches.shape[2])
                for Y in range(0, patches.shape[1] - self.window_size + 1, self.stride):
                    for X in range(0, patches.shape[2] - self.window_size + 1, self.stride):
                        big_patch = torch.zeros(self.patch_size*self.window_size, self.patch_size*self.window_size, 3)
                        patch_batch = patches[0: Y: Y+self.window_size, X:X+self.window_size]
                        for y in range(self.window_size):
                            for x in range(self.window_size):
                                big_patch[
                                    y*self.patch_size:(y+1)*self.patch_size, x*self.patch_size:(x+1)*self.patch_size, :
                                ] = patch_batch[y, x].permute(1, 2, 0)
                        inputs = self.processor(images=big_patch,
                                                text=self.text_prompts,
                                                padding=True,
                                                return_tensors='pt'
                                                ).to(self.device)
                        score = self.model(**inputs).logits_per_image.cpu()
                        score = score.max().item()
                        scores[Y:Y+self.window_size, X:X+self.window_size] += score
                        runs[Y:Y+self.window_size, X:X+self.window_size] += 1
                scores /= runs
                # Repeatedly subtract the mean, clip at zero and renormalize to sharpen the score map.
                for _ in range(5):
                    scores = np.clip(scores - scores.mean(), 0, np.inf)
                    scores = (scores - scores.min()) / (scores.max() - scores.min())
                detection = scores > 0.5
                if detection.any():
                    bnd_box = [
                        detection.nonzero()[:, 0].min().item()*self.patch_size,
                        detection.nonzero()[:, 1].min().item()*self.patch_size,
                        detection.nonzero()[:, 0].max().item()*self.patch_size + 1,
                        detection.nonzero()[:, 1].max().item()*self.patch_size + 1,
                    ]
                    bnd_box = [int(axis*(224/900)) for axis in bnd_box]
                    pred_labels.append(self.get_label(bnd_box, ibnd, image['category']))
                else:
                    pred_labels.append(ind * 2)
            pred_labels = torch.tensor(pred_labels)
            # Ensure predictions and labels are properly shaped and within range
            pred_labels = torch.clamp(pred_labels, 0, self.num_classes - 1)
            labels = torch.clamp(labels, 0, self.num_classes - 1)
            metric = F1(average="macro", num_classes=31)
            labels = torch.tensor(labels)
            pred_labels = torch.tensor(pred_labels)
            metric.update(pred_labels, labels)
            print(pred_labels, len(pred_labels))
            print(labels, len(labels))
            self.f1s.append(metric.compute().item())
            precision = multiclass_precision(pred_labels, labels, average="macro", num_classes=31).item()
            recal = multiclass_recall(pred_labels, labels, average="macro", num_classes=31).item()
            self.recalls.append(recal)
            self.precisions.append(precision)
        pcsn = torch.tensor(self.precisions).mean()
        rcl = torch.tensor(self.recalls).mean()
        f1_max = torch.tensor(self.f1s).max()
        return pcsn, rcl, f1_max

    def compare_box(self, box, target):
        # True when the IoU of the predicted box and the ground-truth box is below the threshold.
        width = min(box[3], target[3]) - max(box[1], target[1])
        height = min(box[2], target[2]) - max(box[0], target[0])
        if width <= 0 or height <= 0:
            return False
        i_area = width * height
        u_area = ((box[2]-box[0])*(box[3]-box[1])) + ((target[2]-target[0])*(target[3]-target[1])) - i_area
        return i_area/u_area < self.compare_threshold

    def get_label(self, box, target, indx):
        return indx * 2 + (not self.compare_box(box, target))


prompts = [(f'a photo of a flawless {object}', f'a photo of a defective {object}') for object in categories]
prompts = [item for sublist in prompts for item in sublist]
analyzer = CLIP_ZSAD_Analyzer(model_id='openai/clip-vit-base-patch32', text_prompts=prompts)
pcsn, rcl, f1_max = analyzer.analyse()
```
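
For context, this is how the numbers are supposed to line up on my side (a quick check, assuming the 15 standard MVTec AD categories are extracted under `/content/dataset`; nothing here is part of the pipeline itself):

```python
# Consistency check between the dataset on disk, the prompts, and the metric setup
# (assumes the 15 standard MVTec AD categories live under /content/dataset).
print(len(categories))                         # object categories found on disk
print(len(prompts))                            # two prompts per category (flawless / defective)
print(analyzer.num_classes)                    # num_classes used for clamping inside analyse()
print(sorted(set(analyzer.test_data.labels)))  # label values MVTec_Dataset actually assigns
```

Why does `multiclass_recall` complain that the tensor sizes (19 vs 31) don't match, and what do I need to change so the recall computation works?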