RankSEG: Optimize segmentation predictions for Dice/IoU without retraining

Hey PyTorch community,

If you’re deploying segmentation models (SAM, DeepLab, SegFormer, UNet, etc.), you’re probably using argmax on your output probabilities to get the final mask.

We built a tool called RankSEG that replaces argmax with a prediction rule that directly optimizes Dice/IoU, giving you better results without any extra training.

Why use it?

  • Free Performance Boost: Squeezes out extra mIoU/Dice score (typically +0.5% to +1.5%) from your existing model
  • Zero Training Cost: Pure post-processing step - no training, no fine-tuning needed
  • Plug-and-Play: Works with any PyTorch model output

Quick Example

import torch
from rankseg import RankSEG

# Your existing model
logits = model(image) # [B, C, H, W]
probs = torch.softmax(logits, dim=1)

# Instead of argmax
# preds = probs.argmax(dim=1)

# Use RankSEG for better Dice/IoU
rankseg = RankSEG(metric='dice', solver='RMA')
preds = rankseg.predict(probs) # Optimized predictions
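
To sanity-check the gain on your own data, you can compare the two prediction maps against ground truth. A minimal sketch, assuming probs and preds from the snippet above plus integer label maps targets of shape [B, H, W] (the mean_dice helper is ours, not part of rankseg):

import torch

def mean_dice(pred, target, num_classes):
    # Macro-averaged Dice over classes; classes absent from both maps are skipped
    scores = []
    for c in range(num_classes):
        p, t = (pred == c), (target == c)
        denom = p.sum() + t.sum()
        if denom == 0:
            continue
        scores.append(2.0 * (p & t).sum() / denom)
    return torch.stack(scores).mean()

baseline = probs.argmax(dim=1)  # the usual argmax prediction
print("argmax  mDice:", mean_dice(baseline, targets, probs.shape[1]).item())
print("RankSEG mDice:", mean_dice(preds, targets, probs.shape[1]).item())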

Try it now: Open In Colab

More Information

This is based on our JMLR and NeurIPS papers on statistically consistent segmentation. The key insight: argmax doesn't actually optimize the metrics we care about (Dice, IoU), so RankSEG solves that optimization problem directly at prediction time.
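
To see why, here's a toy binary example (plain brute force with made-up probabilities, independent of RankSEG itself): when per-pixel foreground probabilities are independent, the mask obtained by thresholding at 0.5 (the binary analogue of argmax) generally does not maximize the expected Dice score.

import itertools

# Hypothetical per-pixel foreground probabilities for a tiny 5-pixel "image"
pixel_probs = [0.6, 0.4, 0.4, 0.4, 0.4]
n = len(pixel_probs)

def dice(pred, truth):
    # Dice between two binary tuples (defined as 1.0 when both are empty)
    inter = sum(p & t for p, t in zip(pred, truth))
    size = sum(pred) + sum(truth)
    return 1.0 if size == 0 else 2.0 * inter / size

def expected_dice(pred):
    # Exact expectation of Dice over all ground truths, pixels independent
    total = 0.0
    for truth in itertools.product([0, 1], repeat=n):
        w = 1.0
        for t, p in zip(truth, pixel_probs):
            w *= p if t else (1 - p)
        total += w * dice(pred, truth)
    return total

threshold_mask = tuple(int(p > 0.5) for p in pixel_probs)  # what argmax/thresholding gives
best_mask = max(itertools.product([0, 1], repeat=n), key=expected_dice)

print("threshold mask:", threshold_mask, "expected Dice =", round(expected_dice(threshold_mask), 3))
print("optimal mask:  ", best_mask, "expected Dice =", round(expected_dice(best_mask), 3))
# The expected-Dice-optimal mask keeps pixels whose probability is below 0.5,
# so thresholding/argmax is not Dice-optimal; RankSEG targets exactly this gap.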

Links

Let me know if it works for your use case!
