Inference-time wrapper to improve SOTA torchvision classification accuracy on ImageNet-1K

Observation

This proposal focuses on improving pretrained torchvision classification models through inference-time processing only.

  • Model weights are unchanged
  • No additional training is required
  • Only image-level preprocessing and logit-level postprocessing are applied
  • Evaluation is performed on the ImageNet-1K validation set
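The constraints above describe a standard frozen-model evaluation loop with two optional hooks. Below is a minimal sketch, assuming PyTorch and an iterable of `(images, labels)` batches. The function name `top1_accuracy` and its `preprocess`/`postprocess` parameters are hypothetical. They only mirror the interface proposed in the Pitch section and are not an existing torchvision API.

```python
from typing import Callable, Iterable, Tuple

import torch
from torch import nn

Hook = Callable[[torch.Tensor], torch.Tensor]


def _identity(x: torch.Tensor) -> torch.Tensor:
    return x


@torch.no_grad()  # inference only: no gradients, weights are never updated
def top1_accuracy(
    model: nn.Module,
    batches: Iterable[Tuple[torch.Tensor, torch.Tensor]],
    preprocess: Hook = _identity,   # hypothetical image-level hook
    postprocess: Hook = _identity,  # hypothetical logit-level hook
) -> float:
    """Top-1 accuracy (in percent) of a frozen model with optional inference-time hooks."""
    model.eval()  # dropout off, BatchNorm uses its stored running statistics
    correct, total = 0, 0
    for images, labels in batches:
        logits = postprocess(model(preprocess(images)))
        correct += (logits.argmax(dim=-1) == labels).sum().item()
        total += labels.numel()
    return 100.0 * correct / total
```

For the baseline, leave both hooks as identities, for example `top1_accuracy(torchvision.models.resnet50(weights="IMAGENET1K_V2"), val_loader)`, where `val_loader` is a hypothetical ImageNet-1K validation loader that uses the weights' own transforms.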

In this setting, we observe large top-1 accuracy improvements across all torchvision-pretrained classification models (from +11.8 to +25.2 percentage points for the models reported below).

For example, the boosted MobileNetV3-Small improves top-1 validation accuracy from 67.668% to 92.830% (+25.162 percentage points) and outperforms ViT-B/16 in:

  • top-1 accuracy (+11.76 percentage points)
  • parameter count (-97.1%)
  • GFLOPs (-89.1%)
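The three comparison figures can be reproduced with simple arithmetic. The sketch below assumes the ViT-B/16 reference numbers that torchvision lists for `ViT_B_16_Weights.IMAGENET1K_V1`: 81.072% top-1, 86.6M parameters, and 17.56 GFLOPs. These numbers are not stated in the issue itself. Accuracy is compared as an absolute difference in percentage points. Parameters and GFLOPs are compared as relative changes.

```python
def relative_change(new: float, old: float) -> float:
    """Percentage change from old to new (negative means a reduction)."""
    return 100.0 * (new - old) / old


# ViT-B/16 reference figures (assumed, from the torchvision model docs)
vit_acc, vit_params_m, vit_gflops = 81.072, 86.6, 17.56
# Boosted MobileNetV3-Small figures, from the results in this issue
mnv3_acc, mnv3_params_m, mnv3_gflops = 92.830, 2.5, 1.92

acc_gain_pp = mnv3_acc - vit_acc  # absolute gain in percentage points
param_change = relative_change(mnv3_params_m, vit_params_m)
gflops_change = relative_change(mnv3_gflops, vit_gflops)

print(f"accuracy: +{acc_gain_pp:.2f} pp")
print(f"params:   {param_change:.1f}%")
print(f"GFLOPs:   {gflops_change:.1f}%")
```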

We provide the following in the linked repository:

  • full code for all torchvision classification models
  • complete experiment logs
  • an interactive notebook (stabletta_3min_quick_repro.ipynb) for rapid reproduction (~3 minutes on a single GPU for StableTTA + MobileNetV2)

Some Results

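Model              | Baseline top-1 (%) | StableTTA top-1 (%) | Params (M) | GFLOPs
------------------ | ------------------ | ------------------- | ---------- | ------
MobileNetV2        | 72.154             | 93.922              | 3.5        | 9.6
MobileNetV3-Small  | 67.668             | 92.830              | 2.5        | 1.92
EfficientNet-B0    | 77.692             | 94.988              | 5.3        | 12.48
ResNet50           | 80.858             | 95.018              | 25.6       | 130.88
EfficientNetV2-S   | 84.228             | 96.046              | 21.5       | 267.84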
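Here is a quick consistency check of the table. It assumes the single-forward-pass GFLOPs that torchvision lists for these weights: 0.30, 0.06, 0.39, 4.09, and 8.37. Every GFLOPs entry is 32× the single-pass figure. This suggests the column reports the total cost of wrapped inference rather than the cost of one forward pass. The issue does not say how that cost is split across augmented views, so treat this reading as an assumption. The script also prints each model's top-1 gain in percentage points.

```python
# Single-forward-pass GFLOPs from the torchvision model docs (assumed, not from the issue)
single_pass_gflops = {
    "MobileNetV2": 0.30,
    "MobileNetV3-Small": 0.06,
    "EfficientNet-B0": 0.39,
    "ResNet50": 4.09,
    "EfficientNetV2-S": 8.37,
}

# (baseline top-1 %, StableTTA top-1 %, reported GFLOPs), copied from the table above
results = {
    "MobileNetV2": (72.154, 93.922, 9.6),
    "MobileNetV3-Small": (67.668, 92.830, 1.92),
    "EfficientNet-B0": (77.692, 94.988, 12.48),
    "ResNet50": (80.858, 95.018, 130.88),
    "EfficientNetV2-S": (84.228, 96.046, 267.84),
}

gains = {}
ratios = {}
for name, (baseline, boosted, gflops) in results.items():
    gains[name] = boosted - baseline                  # absolute gain, percentage points
    ratios[name] = gflops / single_pass_gflops[name]  # reported cost / one forward pass
    print(f"{name:<18} +{gains[name]:.3f} pp  {ratios[name]:.1f}x single-pass GFLOPs")
```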

Baseline accuracies are the top-1 figures reported for the standard torchvision pretrained weights under torchvision's reference evaluation setup.

Pitch

Add an optional inference-time wrapper for torchvision classification models:

stable_tta = StableTTA()
logits = stable_tta.postprocess(model(stable_tta.preprocess(image)))  # original: logits = model(image)
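The sketch below shows one way an object exposing `preprocess`/`postprocess` can wrap an unmodified classifier. It is a hypothetical example, not StableTTA. The class name `FlipAverageTTA` is invented here. The issue does not describe StableTTA's actual transforms, so the sketch uses a generic test-time augmentation: the image plus its horizontal flip, with the softmax probabilities averaged. It assumes PyTorch and NCHW image batches.

```python
import torch
from torch import nn


class FlipAverageTTA:
    """Hypothetical illustration of the preprocess/postprocess interface.

    NOT StableTTA's algorithm: it uses a generic test-time augmentation
    (original + horizontally flipped view, class probabilities averaged).
    """

    num_views = 2

    def preprocess(self, images: torch.Tensor) -> torch.Tensor:
        # (N, C, H, W) -> (2N, C, H, W): first N rows are the originals,
        # last N rows are horizontal flips (flipped along the width axis).
        return torch.cat([images, images.flip(-1)], dim=0)

    def postprocess(self, logits: torch.Tensor) -> torch.Tensor:
        # (2N, K) -> (num_views, N, K), matching the view order used in preprocess.
        views = logits.reshape(self.num_views, -1, logits.shape[-1])
        # Average class probabilities over views and return log-probabilities,
        # so argmax / top-k can be applied exactly as with raw logits.
        return views.softmax(dim=-1).mean(dim=0).log()


# Usage with a tiny stand-in model; a torchvision classifier slots in the same way.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 5)
).eval()
tta = FlipAverageTTA()
images = torch.randn(4, 3, 32, 32)
with torch.no_grad():
    logits = tta.postprocess(model(tta.preprocess(images)))  # original: model(images)
print(logits.shape)
```

Each view costs a full forward pass, so compute grows linearly with the number of views (2× here). Returning log-probabilities keeps the output usable as a drop-in replacement for logits under argmax or top-k.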

The Git repo no longer seems to be available. Have you made it private?