Strange inference behaviour with batch_size = 1

I encountered a strange issue when running inference with a custom DETR model I trained. Here is the forward pass of my model:

    def forward(self, x):
        """Run the DETR-style model on a batch of images.

        Args:
            x: Batch of images, passed straight to ``self.backbone``; the
                backbone's ``"layer4"`` output is used as the feature map.

        Returns:
            dict mapping each entry captured in ``self.decoder_outs`` (filled
            by forward hooks registered elsewhere on the decoder layers) to
            ``{"cl": class logits, "bbox": raw box regression output}``.

        NOTE(review): ``rearrange`` below produces batch-first tokens of
        shape ``(b, h*w, c)``. This assumes ``self.transformer_encoder`` and
        ``self.transformer_decoder`` consume batch-first input. If they are
        ``torch.nn`` transformer modules built with the default
        ``batch_first=False``, attention would run across the batch dimension
        instead of across patches/queries. Predictions would then depend on
        batch size (and degenerate for a batch of 1). TODO confirm.
        """
        # Run the CNN backbone and keep its last stage.
        tokens = self.backbone(x)["layer4"]
        print(f"Feature shape from backbone: {tokens.shape}")  # Debug output

        # Project the backbone features through a simple 1x1 conv.
        tokens = self.conv1x1(tokens)

        # Flatten the spatial grid into a sequence of patch tokens: (b, h*w, c).
        tokens = rearrange(tokens, "b c h w -> b (h w) c")

        # Encode the patch tokens plus the encoder positional encoding.
        out_encoder = self.transformer_encoder(tokens + self.pe_encoder)

        # Give each image in the batch its own copy of the query embeddings,
        # e.g. from (1, 100, 256) to (4, 100, 256) for batch size 4 with 100
        # queries of embedding dimension 256. Then decode against the encoder
        # output. The decoder's return value is not needed here: the hooks
        # store each layer's output in ``self.decoder_outs`` during this call.
        self.transformer_decoder(
            self.queries.repeat(out_encoder.shape[0], 1, 1), out_encoder
        )

        # Apply the class and box heads to every captured decoder-layer output.
        return {
            name: {"cl": self.linear_class(out), "bbox": self.linear_bbox(out)}
            for name, out in self.decoder_outs.items()
        }

I then run inference with the following code:

        # Validate arguments before doing any work. (The previous `if model:`
        # branch raised the error precisely when a model *was* provided.)
        if model is None:
            raise ValueError("No model provided for inference!")
        if dataset is None:
            raise ValueError("No validation dataset provided for inference!")

        # Switch to inference mode (affects layers such as dropout and
        # batch norm).
        model = model.eval()

        data_loader = DataLoader(
            dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
        )
        inputs, (tgt_cl, tgt_bbox, tgt_mask, _) = next(iter(data_loader))
        # The loader can return fewer images than `batch_size` (e.g. small
        # dataset), so size everything on what was actually returned.
        num_images = len(inputs)

        # Move inputs to the model's device (GPU if the model lives there)
        # and run inference.
        device = next(model.parameters()).device
        inputs = inputs.to(device)
        with torch.no_grad():
            outputs = model(inputs)

        # Use the heads of the last decoder layer; index by key instead of
        # relying on the dict's insertion order.
        last_layer = outputs["layer_5"]
        out_cl = last_layer["cl"].cpu()
        out_bbox = last_layer["bbox"].sigmoid().cpu()

        # squeeze=False always yields a 2-D (rows, 2) axes array, so a single
        # row needs no special handling.
        fig, axs = plt.subplots(
            num_images,
            2,
            figsize=(15, 7.5 * num_images),
            constrained_layout=True,
            squeeze=False,
        )

        for ix in range(num_images):
            # Drop padded (empty) ground-truth entries using the target mask.
            t_mask = tgt_mask[ix].bool()
            t_cl = tgt_cl[ix][t_mask]
            t_bbox = tgt_bbox[ix][t_mask]

            # Softmax the class logits. Rescale boxes from [0, 1] cxcywh to
            # pixel xyxy. Assumes the targets use the same normalised cxcywh
            # format as the sigmoid-ed predictions -- TODO confirm.
            o_probs = out_cl[ix].softmax(dim=-1)
            o_bbox = ops.box_convert(
                out_bbox[ix] * image_size, in_fmt="cxcywh", out_fmt="xyxy"
            )
            t_bbox = ops.box_convert(
                t_bbox * image_size, in_fmt="cxcywh", out_fmt="xyxy"
            )

            # Keep only predictions whose top class is not the "no object" class.
            o_keep = o_probs.argmax(-1) != self.empty_class_id
            keep_boxes = o_bbox[o_keep]
            keep_probs = o_probs[o_keep]

            # Apply class-based NMS.
            nms_boxes, nms_probs, nms_classes = class_based_nms(
                keep_boxes, keep_probs, nms_threshold
            )
            # Boxes removed by NMS = boxes before NMS - boxes after NMS.
            num_filtered = keep_boxes.shape[0] - nms_boxes.shape[0]
            if num_filtered > 0:
                print(f"Filtered out {num_filtered} boxes with NMS...")

            # Plot image with predictions on the left.
            self._visualize_image(
                inputs[ix].cpu(), nms_boxes, nms_classes, nms_probs, ax=axs[ix, 0]
            )
            axs[ix, 0].set_title("Predictions")

            # Plot image with ground-truth boxes on the right.
            self._visualize_image(inputs[ix].cpu(), t_bbox, t_cl, ax=axs[ix, 1])
            axs[ix, 1].set_title("Ground Truth")

I get reasonable results with any batch size other than 1, but with a batch size of 1 the model never produces any detections. Why could that be?