Observation
This proposal focuses on improving pretrained torchvision classification models through inference-time processing only.
- Model weights are unchanged
- No additional training is required
- Only image-level preprocessing and logit-level postprocessing are applied
- Evaluation is performed on the ImageNet-1K validation set
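The protocol above (frozen weights, no training, evaluation only) corresponds to a standard inference-only top-1 loop. Below is a minimal sketch, assuming a PyTorch `model` and an iterable of `(images, labels)` batches. `top1_accuracy` and its `transform` hook are hypothetical names used for illustration, not part of torchvision.

```python
import torch
from torch import nn


def top1_accuracy(model: nn.Module, batches, transform=None) -> float:
    """Top-1 accuracy (%) of a frozen classifier over (images, labels) batches.

    `transform(model, images)` optionally replaces the plain forward pass,
    e.g. to insert inference-time pre- and postprocessing around the model.
    """
    model.eval()  # use running BatchNorm statistics, disable dropout
    correct, total = 0, 0
    with torch.inference_mode():  # no autograd graph and no optimizer: weights stay fixed
        for images, labels in batches:
            logits = model(images) if transform is None else transform(model, images)
            correct += (logits.argmax(dim=1) == labels).sum().item()
            total += labels.numel()
    return 100.0 * correct / total
```

With torchvision, `model` and its matching eval transforms could come from, for example, `weights = MobileNet_V3_Small_Weights.IMAGENET1K_V1`, `mobilenet_v3_small(weights=weights)` and `weights.transforms()`. The wrapper from the Pitch would then be passed as `transform=lambda m, x: stable_tta.postprocess(m(stable_tta.preprocess(x)))`.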
In this setting, we observe large improvements in top-1 accuracy across all torchvision-pretrained models.
For example, the boosted MobileNetV3-Small improves top-1 validation accuracy from 67.668% to 92.830% (+25.162 percentage points), outperforming ViT-B/16 on:
- top-1 accuracy (+11.75 percentage points)
- parameter count (97.1% fewer parameters)
- GFLOPs (89.1% fewer)
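The comparison figures can be checked with a few lines of arithmetic. This is a minimal sketch, assuming the ViT-B/16 reference values are torchvision's `ViT_B_16_Weights.IMAGENET1K_V1` entries (81.072% top-1, 86.6M parameters, 17.56 GFLOPs); those three numbers are an assumption and are not stated in this proposal.

```python
# MobileNetV3-Small + StableTTA, from the results table below.
ours = {"acc": 92.830, "params_m": 2.5, "gflops": 1.92}
baseline_acc = 67.668  # MobileNetV3-Small without StableTTA

# ASSUMED ViT-B/16 reference values (torchvision IMAGENET1K_V1 weights).
vit_b16 = {"acc": 81.072, "params_m": 86.6, "gflops": 17.56}

gain_pp = ours["acc"] - baseline_acc            # percentage points
acc_margin_pp = ours["acc"] - vit_b16["acc"]    # percentage points
param_reduction = 1 - ours["params_m"] / vit_b16["params_m"]  # relative
flop_reduction = 1 - ours["gflops"] / vit_b16["gflops"]       # relative

print(f"gain vs baseline:     +{gain_pp:.3f} pp")
print(f"accuracy vs ViT-B/16: +{acc_margin_pp:.3f} pp")
print(f"params: -{param_reduction:.1%}, GFLOPs: -{flop_reduction:.1%}")
```

Under these assumed reference values the accuracy margin is about 11.76 points, consistent with the stated +11.75 up to rounding.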
We provide:
- full code for all torchvision classification models
- complete experiment logs
- an interactive notebook (`stabletta_3min_quick_repro.ipynb`) for rapid reproduction (~3 minutes on a single GPU for StableTTA + MobileNetV2)
Some Results
| Model | Baseline top-1 (%) | StableTTA top-1 (%) | Params (M) | GFLOPs (with StableTTA) |
|---|---|---|---|---|
| MobileNetV2 | 72.154 | 93.922 | 3.5 | 9.6 |
| MobileNetV3-Small | 67.668 | 92.830 | 2.5 | 1.92 |
| EfficientNet-B0 | 77.692 | 94.988 | 5.3 | 12.48 |
| ResNet50 | 80.858 | 95.018 | 25.6 | 130.88 |
| EfficientNetV2-S | 84.228 | 96.046 | 21.5 | 267.84 |
Baseline accuracies are those of the standard torchvision pretrained weights under the reference evaluation setup.
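The per-model gains and compute overhead follow directly from the table. This is a minimal sketch; the single-pass GFLOPs values below are an assumption, taken from torchvision's published weight metadata rather than from this proposal. Under that assumption every GFLOPs entry in the table is exactly 32 times the backbone's single-pass cost, so the column reports the cost of the wrapped model, not the backbone alone.

```python
# (baseline top-1 %, StableTTA top-1 %, table GFLOPs), from the results table.
results = {
    "MobileNetV2": (72.154, 93.922, 9.6),
    "MobileNetV3-Small": (67.668, 92.830, 1.92),
    "EfficientNet-B0": (77.692, 94.988, 12.48),
    "ResNet50": (80.858, 95.018, 130.88),
    "EfficientNetV2-S": (84.228, 96.046, 267.84),
}

# ASSUMED single-pass GFLOPs from torchvision's weight metadata.
single_pass_gflops = {
    "MobileNetV2": 0.30,
    "MobileNetV3-Small": 0.06,
    "EfficientNet-B0": 0.39,
    "ResNet50": 4.09,
    "EfficientNetV2-S": 8.37,
}

for name, (base, ours, gflops) in results.items():
    gain = ours - base                              # percentage points
    multiplier = gflops / single_pass_gflops[name]  # wrapped cost / single pass
    print(f"{name:18s} +{gain:6.3f} pp   cost x{multiplier:.1f}")
```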
Pitch
Add an optional inference-time wrapper for torchvision classification models:
```python
# Proposed usage; StableTTA is the wrapper pitched here (not yet in torchvision).
stable_tta = StableTTA()
logits = stable_tta.postprocess(model(stable_tta.preprocess(image)))  # original: logits = model(image)
```
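This proposal does not specify what `preprocess` and `postprocess` compute. To illustrate only the shape of such a two-hook interface, here is a minimal sketch built on a different, well-known technique: horizontal-flip test-time augmentation with logit averaging. `FlipTTA` is a hypothetical name, and this is not the StableTTA method.

```python
import torch
from torch import Tensor


class FlipTTA:
    """Hypothetical two-hook TTA wrapper (horizontal flip + logit averaging).

    Illustrates the preprocess/postprocess interface only; it is NOT StableTTA.
    """

    num_views = 2  # the original image and its horizontal mirror

    def preprocess(self, images: Tensor) -> Tensor:
        # (B, C, H, W) -> (B * num_views, C, H, W); views of one image stay adjacent.
        views = torch.stack([images, images.flip(dims=(-1,))], dim=1)
        return views.flatten(0, 1)

    def postprocess(self, logits: Tensor) -> Tensor:
        # (B * num_views, K) -> (B, K): average the logits over each image's views.
        return logits.reshape(-1, self.num_views, logits.shape[-1]).mean(dim=1)
```

It is used exactly like the pitched API: `tta = FlipTTA()` followed by `logits = tta.postprocess(model(tta.preprocess(images)))`. Keeping each image's views adjacent makes the reshape in `postprocess` trivial. Each extra view costs one more backbone forward pass, so compute grows linearly with `num_views`.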