Hello.
This might be a bit of a long question to ask, but I really can't figure it out by myself, and any help is appreciated.
I was working through a tutorial on zero-shot learning, and below is the Colab script I am working on, which is from James Briggs' tutorial (Zero Shot Object Detection with OpenAI's CLIP | Pinecone).
Question 1: I am assuming that the dataset below (jamescalam/image-text-demo) is what contains all the necessary class labels and training images. If so, how can I prepare my own data so I can control the list of classes I want to detect?
from datasets import load_dataset  # pip install datasets

data = load_dataset(
    "jamescalam/image-text-demo",
    split="train",
    revision="180fdae"
)
data
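For example, would something like this be the right way to point the script at my own images? (A sketch on my end, assuming a local folder of images; my_images/ is a hypothetical path, and imagefolder is the generic image loader from the datasets library.)

from datasets import load_dataset

# sketch: load my own images from a local folder instead of the demo set
# (my_images/ is a hypothetical directory of .jpg/.png files)
data = load_dataset("imagefolder", data_dir="my_images", split="train")

# my current understanding: with zero-shot detection, the "classes" are just
# the text prompts passed to detect(), not labels stored in the dataset
prompts = ["cat", "butterfly"]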
Question 2: Below is the core part of the detection code.
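(For reference, processor, model, and device come from earlier cells in the tutorial; my setup is roughly the following sketch, assuming the openai/clip-vit-base-patch32 checkpoint the tutorial uses.)

import torch
from transformers import CLIPProcessor, CLIPModel

# rough sketch of the earlier setup cells, as I understand them
model_id = "openai/clip-vit-base-patch32"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained(model_id).to(device)
processor = CLIPProcessor.from_pretrained(model_id)

And the detection code itself: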
from tqdm.auto import tqdm
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.patches as patches

colors = ['#FAFF00', '#8CF1FF']
def get_patches(img, patch_size=256):
    # add extra dimension for later calculations
    img_patches = img.data.unfold(0, 3, 3)
    # break the image into patches (in height dimension)
    img_patches = img_patches.unfold(1, patch_size, patch_size)
    # break the image into patches (in width dimension)
    img_patches = img_patches.unfold(2, patch_size, patch_size)
    return img_patches
def get_scores(img_patches, prompt, window=6, stride=1, patch_size=256):
    # initialize scores and runs arrays
    scores = torch.zeros(img_patches.shape[1], img_patches.shape[2])
    runs = torch.ones(img_patches.shape[1], img_patches.shape[2])
    # iterate through patches
    for Y in range(0, img_patches.shape[1]-window+1, stride):
        for X in range(0, img_patches.shape[2]-window+1, stride):
            # initialize array to store the big patch
            big_patch = torch.zeros(patch_size*window, patch_size*window, 3)
            # get a single window of patches
            patch_batch = img_patches[0, Y:Y+window, X:X+window]
            # iteratively build the big patch from the window
            for y in range(window):
                for x in range(window):
                    big_patch[
                        y*patch_size:(y+1)*patch_size,
                        x*patch_size:(x+1)*patch_size, :
                    ] = patch_batch[y, x].permute(1, 2, 0)
            inputs = processor(
                images=big_patch,     # image transmitted to the model
                return_tensors="pt",  # return PyTorch tensors
                text=prompt,          # text transmitted to the model
                padding=True
            ).to(device)  # move to device if possible
            score = model(**inputs).logits_per_image.item()
            # sum up similarity scores
            scores[Y:Y+window, X:X+window] += score
            # count the number of runs per patch
            runs[Y:Y+window, X:X+window] += 1
    # calculate average scores
    scores /= runs
    # clip scores, keeping only above-average activations
    for _ in range(3):
        scores = np.clip(scores - scores.mean(), 0, np.inf)
    # normalize scores to [0, 1]
    scores = (scores - scores.min()) / (scores.max() - scores.min())
    return scores
def get_box(scores, patch_size=256, threshold=0.5):
    # boolean mask of patches above the threshold, as a NumPy array
    detection = np.asarray(scores > threshold)
    # indices of all detected patches
    y_idx, x_idx = np.nonzero(detection)
    # find box corners (in patch co-ords)
    y_min, y_max = y_idx.min().item(), y_idx.max().item() + 1
    x_min, x_max = x_idx.min().item(), x_idx.max().item() + 1
    # convert from patch co-ords to pixel co-ords
    y_min *= patch_size
    y_max *= patch_size
    x_min *= patch_size
    x_max *= patch_size
    # calculate box height and width
    height = y_max - y_min
    width = x_max - x_min
    return x_min, y_min, width, height
def detect(prompts, img, patch_size=256, window=6, stride=1, threshold=0.5):
    # build image patches for detection
    img_patches = get_patches(img, patch_size)
    # number of patches along the height (Y) and width (X) dimensions
    Y, X = img_patches.shape[1], img_patches.shape[2]
    # convert image to format for displaying with matplotlib
    image = np.moveaxis(img.data.numpy(), 0, -1)
    # initialize plot to display image + bounding boxes (figsize is width, height)
    fig, ax = plt.subplots(figsize=(X*0.5, Y*0.5))
    ax.imshow(image)
    # process image through object detection steps
    for i, prompt in enumerate(tqdm(prompts)):
        scores = get_scores(img_patches, prompt, window, stride, patch_size)
        x, y, width, height = get_box(scores, patch_size, threshold)
        # create the bounding box
        rect = patches.Rectangle(
            (x, y), width, height,
            linewidth=3, edgecolor=colors[i], facecolor='none'
        )
        # add the patch to the Axes
        ax.add_patch(rect)
    plt.show()
Then, I can use the code below to detect two object classes, which works fine.
detect(["cat", "butterfly"], img, window=6, stride=1)
But if I try to detect more than two classes, it fails:
detect(["cat eye", "butterfly", "white fur", "cat nose"], img, window=6, stride=1)
IndexError Traceback (most recent call last)
<ipython-input-63-6cca2863a857> in <module>
1 #detect(["cat", "butterfly"], img, window=6, stride=1)
----> 2 detect(["cat eye", "butterfly", "white fur", "cat nose"], img, window=6, stride=1)
3 #detect(["white fur"], img, window=6, stride=1)
<ipython-input-51-f10b7a35c9bb> in detect(prompts, img, patch_size, window, stride, threshold)
79 x, y, width, height = get_box(scores, patch_size, threshold)
80 # create the bounding box
---> 81 rect = patches.Rectangle((x, y), width, height, linewidth=3, edgecolor=colors[i], facecolor='none')
82 # add the patch to the Axes
83 ax.add_patch(rect)
IndexError: list index out of range
What should I adjust so I can detect more than two classes?
Thank you!