### Problem Summary
When training a model with `F.grid_sample` using **Automatic Mixed Precision (AMP)** on **1024x1024 images**, the backward pass consistently fails with:
```
torch.linalg.solve: (Batch element 0): The solver failed because the input matrix is singular.
```
The error occurs in `grid_sampler_2d_backward_cuda` during gradient computation. Training works fine without AMP (full FP32), and AMP works on smaller images (256x256, 512x512).
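One way to double-check which forward call the failing backward op belongs to is autograd's anomaly detection. The sketch below is a minimal illustration, not part of our training code: `run_with_anomaly_detection` and `step_fn` are hypothetical names, and the toy model only shows the call pattern. Anomaly mode has to wrap the forward pass as well, because that is when the forward tracebacks are recorded.

```python
import torch


def run_with_anomaly_detection(step_fn):
    """Run one forward/backward step under anomaly detection.

    step_fn is a hypothetical zero-argument callable that runs the forward
    pass and returns a scalar loss. If a backward op raises (or produces NaN),
    autograd also reports the traceback of the forward call that created it.
    """
    with torch.autograd.detect_anomaly():
        loss = step_fn()   # forward pass must happen inside the context
        loss.backward()    # a failing backward op now points at its forward call
    return loss


# Toy usage on CPU; in the AMP case step_fn would run the model under
# autocast and return scaler.scale(loss).
toy_model = torch.nn.Linear(4, 1)
toy_input = torch.randn(2, 4)
toy_loss = run_with_anomaly_detection(lambda: toy_model(toy_input).pow(2).mean())
```

Anomaly detection slows training noticeably, so it is only meant for a single diagnostic step.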
### Environment
- **GPU**: NVIDIA GeForce RTX 3090 (24GB)
- **PyTorch**: 2.x (latest stable)
- **CUDA**: 11.8
- **OS**: Ubuntu 20.04
### Model Architecture
Parametric models that use coordinate transformations with `grid_sample`:
- Sequential model chaining `parametric_a1` and `parametric_a12`
- Both use `F.grid_sample` for image warping
- Input: (B, 3, 1024, 1024)
- Output: (B, 3, 1024, 1024)
### What We’ve Tested (All Failed)
**1. Conservative GradScaler Settings:**
```python
scaler = torch.cuda.amp.GradScaler(
    init_scale=64,        # Default: 65536
    growth_factor=1.001,  # Default: 2.0
    backoff_factor=0.5,
    growth_interval=5000  # Default: 2000
)
```
- Tested init_scale from 8 to 512
- Tested growth_factor from 1.00001 to 2.0
- Result: Same singular matrix error
**2. Different Batch Sizes:**
- Batch size 1, 2, 3, 4, 6
- Result: Same error regardless of batch size
**3. Different Learning Rates:**
- LR from 0.0000001 to 0.1 (6 orders of magnitude)
- With and without gradient clipping (0.01 to 50.0)
- Result: Same error
**4. Different Data Types:**
- `torch.float16` (default AMP) - Fails
- `torch.bfloat16` (on Ampere GPU) - Same error
**5. Grid Sample Mode:**
- `mode='bilinear'` - Fails
- `mode='nearest'` - Fails
**6. Zero Weight Decay:**
- `weight_decay=0` - Fails
**7. No Gradient Clipping:**
- Disabled clipping entirely - Fails
**8. TF32 Backend Settings:**
```python
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
```
- Result: Same error
### What Works
1. **Disable AMP** - Training works perfectly in FP32
2. **Use smaller images** - AMP works at 256x256 and 512x512
3. **Different architectures** - UNet++ (no grid_sample) works with AMP at 1024x1024
### The Core Question
**Is there any way to make `F.grid_sample` backward pass work with AMP on 1024x1024 images?**
Specifically, we’re looking for:
1. Is there a way to force the `grid_sample` backward pass to use FP32 while keeping the rest of the model in mixed precision?
2. Any undocumented AMP settings that affect linear algebra operations?
3. Is this a known PyTorch limitation with a roadmap fix?
4. Alternative approaches we haven’t considered?
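For item 1, a minimal sketch of one option we are aware of, assuming the standard `torch.autocast` context manager: disable autocast locally and cast the inputs to float32. `grid_sample_fp32` is a hypothetical helper name, not a PyTorch API. Autograd runs each backward op in the dtype its forward op used, so the backward of this call should also run in FP32. We have not confirmed that this changes the behaviour at 1024x1024; the autocast op reference already lists `grid_sample` among the ops that autocast to float32 on CUDA, so the wrapper mainly makes the input dtypes explicit.

```python
import torch
import torch.nn.functional as F


def grid_sample_fp32(inp, grid, **kwargs):
    """Hypothetical wrapper: run F.grid_sample in float32 even when it is
    called from inside an autocast region."""
    # Autocast is switched off only for this call; the surrounding model
    # code keeps running in mixed precision.
    with torch.autocast(device_type=inp.device.type, enabled=False):
        return F.grid_sample(inp.float(), grid.float(), **kwargs)
```

In the model's `forward`, the call `F.grid_sample(x, grid, ...)` would be replaced by `grid_sample_fp32(x, grid, ...)` with the same keyword arguments.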
### Minimal Reproducible Example
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler


class SimpleGridSampleModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Simple coordinate transformation
        self.transform = nn.Sequential(
            nn.Linear(2, 64),
            nn.ReLU(),
            nn.Linear(64, 2),
            nn.Sigmoid(),
        )

    def forward(self, x):
        B, C, H, W = x.shape
        device = x.device
        # Create coordinate grid
        y, x_coords = torch.meshgrid(
            torch.linspace(0, 1, H, device=device),
            torch.linspace(0, 1, W, device=device),
            indexing='ij',
        )
        coord_grid = torch.stack([x_coords, y], dim=-1)  # [H, W, 2]
        coord_grid = coord_grid.unsqueeze(0).expand(B, -1, -1, -1)  # [B, H, W, 2]
        # Transform coordinates
        flat_coords = coord_grid.view(B, -1, 2)
        transformed = self.transform(flat_coords)
        transformed = transformed.view(B, H, W, 2)
        # Normalize to [-1, 1] for grid_sample
        grid = transformed * 2 - 1
        # grid_sample - THIS FAILS in backward pass with AMP at 1024x1024
        return F.grid_sample(x, grid, mode='bilinear', padding_mode='border', align_corners=True)


# Training setup
device = 'cuda'
model = SimpleGridSampleModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scaler = GradScaler(init_scale=64, growth_factor=1.001)  # Conservative settings

# 1024x1024 input - FAILS
data = torch.randn(1, 3, 1024, 1024).to(device)
optimizer.zero_grad()
with autocast():
    output = model(data)
    loss = F.mse_loss(output, data)
scaler.scale(loss).backward()  # torch.linalg.solve singular matrix error here
scaler.step(optimizer)
scaler.update()
```
### Additional Context
**Why we need AMP:**
- 1024x1024 images use ~11GB GPU memory per batch
- AMP reduces memory by ~30-40%, allowing larger batch sizes
- FP32 limits us to batch_size=1 with marginal memory headroom
- Training time is 2x slower without AMP
**The mathematical issue (our hypothesis, not verified against the PyTorch source):**
Our assumption is that the `grid_sample` backward pass solves linear systems for the bilinear interpolation gradients. At 1024x1024 these matrices would be (H×W) × (H×W), roughly 1M × 1M, and FP16 precision loss would turn near-singular matrices into fully singular ones.
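One precision effect that can be checked directly is how finely FP16 resolves sampling coordinates. The sketch below compares the spacing between adjacent pixel centres in normalized [-1, 1] coordinates (with `align_corners=True`) to the spacing of FP16 values just below 1.0. `fp16_steps_per_pixel` is a hypothetical helper. This only matters if the grid is still in FP16 when it reaches `grid_sample`, and it does not by itself explain the `torch.linalg.solve` error.

```python
import torch


def fp16_steps_per_pixel(size):
    """How many distinct float16 values fit between two adjacent pixel
    centres near the image edge, in grid_sample's normalized coordinates."""
    # Adjacent pixel centres are 2 / (size - 1) apart when align_corners=True.
    pixel_step = 2.0 / (size - 1)
    # float16 has 10 mantissa bits: eps = 2**-10 is the spacing in [1, 2),
    # so values in [0.5, 1) are spaced eps / 2 apart.
    fp16_spacing = torch.finfo(torch.float16).eps / 2
    return pixel_step / fp16_spacing


for size in (256, 512, 1024):
    print(size, round(fp16_steps_per_pixel(size), 2))
```

Near the edges this gives about 16 representable coordinates per pixel at 256, about 8 at 512, and about 4 at 1024, so sub-pixel offsets become much coarser as the resolution grows.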
### Questions for the Community
1. Is there a way to selectively disable AMP for specific operations (like `grid_sample`) while keeping AMP enabled for the rest of the model?
2. Are there any PyTorch internals we can monkey-patch to make `grid_sample_backward` use FP32?
3. Is this on the PyTorch roadmap to fix? If so, what’s the timeline?
4. Are there alternative implementations of differentiable image warping that work with AMP on large images?
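On question 4, one alternative we can think of is writing the bilinear warp with plain tensor ops, so the backward pass consists only of elementwise and gather gradients. The sketch below is an illustration under stated assumptions: `bilinear_warp` is a hypothetical helper, it covers only `mode='bilinear'`, `padding_mode='border'`, `align_corners=True`, it casts to float32 internally, and we have not tested it under AMP at 1024x1024. It will likely use more memory than the fused kernel because it materializes four gathered tensors.

```python
import torch
import torch.nn.functional as F


def bilinear_warp(img, grid):
    """Sample img (B, C, H, W) at grid (B, Ho, Wo, 2), grid in [-1, 1] as (x, y).
    Intended to match bilinear / border padding / align_corners=True."""
    B, C, H, W = img.shape
    img, grid = img.float(), grid.float()
    # Map normalized coordinates to pixel coordinates and clamp to the border.
    x = ((grid[..., 0] + 1) * 0.5 * (W - 1)).clamp(0, W - 1)
    y = ((grid[..., 1] + 1) * 0.5 * (H - 1)).clamp(0, H - 1)
    x0f, y0f = x.floor(), y.floor()
    wx = (x - x0f).unsqueeze(1)  # fractional offsets, shape (B, 1, Ho, Wo)
    wy = (y - y0f).unsqueeze(1)
    x0, y0 = x0f.long(), y0f.long()
    x1 = (x0 + 1).clamp(max=W - 1)
    y1 = (y0 + 1).clamp(max=H - 1)

    flat = img.reshape(B, C, H * W)

    def gather(yi, xi):
        # Look up pixel values at integer positions (yi, xi) for every channel.
        idx = (yi * W + xi).reshape(B, 1, -1).expand(-1, C, -1)
        return flat.gather(2, idx).reshape(B, C, *yi.shape[1:])

    top = gather(y0, x0) * (1 - wx) + gather(y0, x1) * wx
    bottom = gather(y1, x0) * (1 - wx) + gather(y1, x1) * wx
    return top * (1 - wy) + bottom * wy


# Small CPU comparison against F.grid_sample with the same settings.
img = torch.randn(2, 3, 16, 16)
grid = torch.rand(2, 16, 16, 2) * 2 - 1
ref = F.grid_sample(img, grid, mode='bilinear', padding_mode='border', align_corners=True)
print((bilinear_warp(img, grid) - ref).abs().max())
```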
Any help or insights would be greatly appreciated!
---
**Tags:** `automatic-mixed-precision`, `grid-sample`, `cuda`, `gradient-computation`, `numerical-stability`