I encountered a strange issue when running inference with a custom-made DETR model that I trained. Here is the forward pass of my model:
def forward(self, x):
    """Run the model on a batch and build predictions for every decoder layer.

    The input goes through the CNN backbone, a 1x1 convolution, the
    transformer encoder and finally the transformer decoder. The value
    returned by the decoder call is not used directly: predictions are
    built from ``self.decoder_outs``, which a forward hook registered
    outside this method is expected to fill with each decoder layer's
    output during the forward pass.

    NOTE(review): the tokens are laid out batch-first, ``(b, h*w, c)``.
    If the encoder/decoder are ``torch.nn`` transformer modules they
    presumably have to be built with ``batch_first=True`` -- confirm at
    construction time.

    Args:
        x: Input batch handed to ``self.backbone``.

    Returns:
        A dict with one entry per item of ``self.decoder_outs``; each value
        is a dict holding the output of ``self.linear_class`` under ``"cl"``
        and the output of ``self.linear_bbox`` under ``"bbox"``.
    """
    # Take the "layer4" feature map produced by the backbone.
    tokens = self.backbone(x)["layer4"]
    print(f"Feature shape from backbone: {tokens.shape}")  # Debug output

    # Apply the 1x1 conv, then flatten the spatial grid into a sequence of
    # patches: (b, c, h, w) -> (b, h*w, c).
    patches = rearrange(self.conv1x1(tokens), "b c h w -> b (h w) c")

    # Encode the patches together with the encoder positional encoding.
    memory = self.transformer_encoder(patches + self.pe_encoder)

    # Give every image in the batch its own copy of the query embeddings,
    # e.g. (1, 100, 256) -> (4, 100, 256) for a batch of 4 images with
    # 100 queries of embedding dimension 256.
    batch_queries = self.queries.repeat(memory.shape[0], 1, 1)

    # The return value is deliberately dropped; the per-layer outputs are
    # read from `self.decoder_outs` below.
    self.transformer_decoder(batch_queries, memory)

    # Run the class and box heads on the stored output of each decoder layer.
    return {
        name: {"cl": self.linear_class(feat), "bbox": self.linear_bbox(feat)}
        for name, feat in self.decoder_outs.items()
    }
I then run inference with the following code:
# Run the model on one batch drawn from `dataset` and plot, per image, the
# predictions (left) next to the ground truth (right).
# `model`, `dataset`, `batch_size`, `collate_fn`, `image_size` and
# `nms_threshold` are expected to come from the enclosing scope.

# Validate the arguments up front. Compare against None explicitly: module
# containers that define __len__ can be falsy even though they are valid.
if model is None:
    raise ValueError("No model provided for inference!")
if dataset is None:
    raise ValueError("No validation dataset provided for inference!")

model = model.eval()
model.to(self.device)

data_loader = DataLoader(
    dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
)
inputs, (tgt_cl, tgt_bbox, tgt_mask, _) = next(iter(data_loader))

# Move inputs to the target device and run inference
inputs = inputs.to(self.device)
with torch.no_grad():
    outputs = model(inputs)

# Read the two heads by key rather than relying on the dict's value order.
last_layer = outputs["layer_5"]
out_cl = last_layer["cl"].cpu()
out_bbox = last_layer["bbox"].sigmoid().cpu()

# The loader may return fewer than `batch_size` samples (small dataset), so
# size the figure and the loop from the batch that was actually produced.
num_images = len(inputs)

# squeeze=False always yields a 2-D axes array, so a single-image batch
# needs no special casing.
fig, axs = plt.subplots(
    num_images,
    2,
    figsize=(15, 7.5 * num_images),
    constrained_layout=True,
    squeeze=False,
)

for ix in range(num_images):
    # Get true and predicted boxes for this image
    o_cl = out_cl[ix]
    t_cl = tgt_cl[ix]
    o_bbox = out_bbox[ix]
    t_bbox = tgt_bbox[ix]
    t_mask = tgt_mask[ix].bool()

    # Filter out empty boxes from the ground truths
    t_cl = t_cl[t_mask]
    t_bbox = t_bbox[t_mask]

    # Apply softmax and rescale boxes
    o_probs = o_cl.softmax(dim=-1)
    o_bbox = ops.box_convert(
        o_bbox * image_size, in_fmt="cxcywh", out_fmt="xyxy"
    )
    t_bbox = ops.box_convert(
        t_bbox * image_size, in_fmt="cxcywh", out_fmt="xyxy"
    )

    # Filter "no object" predictions
    o_keep = o_probs.argmax(-1) != self.empty_class_id
    keep_boxes = o_bbox[o_keep]
    keep_probs = o_probs[o_keep]

    # Apply class-based NMS
    nms_boxes, nms_probs, nms_classes = class_based_nms(
        keep_boxes, keep_probs, nms_threshold
    )

    # Boxes removed by NMS = boxes before NMS - boxes after NMS. (The
    # operands used to be swapped, so the count was never positive and the
    # message below could never be printed.)
    num_filtered = keep_boxes.shape[0] - nms_boxes.shape[0]
    if nms_boxes.shape[0] > 0 and num_filtered > 0:
        print(f"Filtered out {num_filtered} boxes with NMS...")

    # Move the image to the CPU once and reuse it for both panels.
    image = inputs[ix].cpu()

    # Plot image with predictions on the left
    self._visualize_image(
        image, nms_boxes, nms_classes, nms_probs, ax=axs[ix, 0]
    )
    axs[ix, 0].set_title("Predictions")

    # Plot image with ground truth boxes on the right
    self._visualize_image(image, t_bbox, t_cl, ax=axs[ix, 1])
    axs[ix, 1].set_title("Ground Truth")
I get reasonable results with any batch size other than 1, but with a batch size of 1 the model never produces any detections. Why could that be?