torch.linalg.solve singular matrix in grid_sample backward pass with AMP on large images - all hyperparameter combinations tested

### Problem Summary

When training a model with `F.grid_sample` using **Automatic Mixed Precision (AMP)** on **1024x1024 images**, the backward pass consistently fails with:

```

torch.linalg.solve: (Batch element 0): The solver failed because the input matrix is singular.

```

The error occurs in `grid_sampler_2d_backward_cuda` during gradient computation. Training works fine without AMP (full FP32), and AMP works on smaller images (256x256, 512x512).
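For anyone reproducing this, the failing backward node can be traced back to its forward call with autograd anomaly detection. The sketch below uses toy tensors and a hypothetical helper name (`backward_with_anomaly_detection`), not our training code. On the real model the same context manager would wrap the AMP forward and backward.

```python
import torch
import torch.nn.functional as F


def backward_with_anomaly_detection(image, grid):
    """Run grid_sample forward/backward with anomaly detection enabled.

    If backward raises, PyTorch also reports the traceback of the forward
    call that created the failing node.
    """
    grid = grid.clone().requires_grad_(True)
    with torch.autograd.detect_anomaly():
        out = F.grid_sample(image, grid, mode='bilinear',
                            padding_mode='border', align_corners=True)
        out.mean().backward()
    return grid.grad


# Toy inputs (placeholders): one 3-channel 8x8 image and a random grid in [-1, 1]
image = torch.randn(1, 3, 8, 8)
grid = torch.rand(1, 8, 8, 2) * 2 - 1
grid_grad = backward_with_anomaly_detection(image, grid)
```

Anomaly detection slows training noticeably, so it is only meant for a single debugging step.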

### Environment

- **GPU**: NVIDIA GeForce RTX 3090 (24GB)

- **PyTorch**: 2.x (latest stable)

- **CUDA**: 11.8

- **OS**: Ubuntu 20.04

### Model Architecture

Parametric models that use coordinate transformations with `grid_sample`:

- Sequential model chaining `parametric_a1` and `parametric_a12`

- Both use `F.grid_sample` for image warping

- Input: (B, 3, 1024, 1024)

- Output: (B, 3, 1024, 1024)

### What We’ve Tested (All Failed)

**1. Conservative GradScaler Settings:**

```python
scaler = torch.cuda.amp.GradScaler(
    init_scale=64,         # Default: 65536
    growth_factor=1.001,   # Default: 2.0
    backoff_factor=0.5,
    growth_interval=5000   # Default: 2000
)
```

- Tested init_scale from 8 to 512

- Tested growth_factor from 1.00001 to 2.0

- Result: ❌ Same singular matrix error

**2. Different Batch Sizes:**

- Batch size 1, 2, 3, 4, 6

- Result: ❌ Same error regardless of batch size

**3. Different Learning Rates:**

- LR from 0.0000001 to 0.1 (6 orders of magnitude)

- With and without gradient clipping (0.01 to 50.0)

- Result: ❌ Same error

**4. Different Data Types:**

- `torch.float16` (default AMP) - ❌ Fails

- `torch.bfloat16` (on Ampere GPU) - ❌ Same error
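For reference, switching the autocast dtype looks roughly like the sketch below. The layer and tensor here are placeholders rather than our model, and the device falls back to CPU only so the snippet runs anywhere.

```python
import torch
import torch.nn as nn

device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
layer = nn.Linear(2, 4).to(device_type)          # placeholder layer
coords = torch.rand(16, 2, device=device_type)   # placeholder input

# bfloat16 has the same exponent range as FP32, so loss scaling is
# typically not required for it (unlike float16).
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
    out_bf16 = layer(coords)

print(out_bf16.dtype)
```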

**5. Grid Sample Mode:**

- `mode='bilinear'` - ❌ Fails

- `mode='nearest'` - ❌ Fails

**6. Zero Weight Decay:**

- `weight_decay=0` - ❌ Fails

**7. No Gradient Clipping:**

- Disabled clipping entirely - ❌ Fails

**8. TF32 Backend Settings:**

```python

torch.backends.cuda.matmul.allow_tf32 = True

torch.backends.cudnn.allow_tf32 = True

```

- Result: ❌ Same error

### What Works

1. **Disable AMP** - Training works perfectly in FP32

2. **Use smaller images** - AMP works at 256x256 and 512x512

3. **Different architectures** - UNet++ (no grid_sample) works with AMP at 1024x1024
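The resolution dependence in item 2 can be checked with a sweep like the sketch below. It is a simplified stand-in, not our actual models: a learnable identity affine transform replaces the parametric networks, and `amp_backward_ok` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F


def amp_backward_ok(size, device_type):
    """Try one autocast forward/backward at size x size; return (ok, error message)."""
    image = torch.randn(1, 3, size, size, device=device_type)
    # Learnable identity affine transform (stand-in for the parametric models)
    theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]],
                         device=device_type, requires_grad=True)
    try:
        with torch.autocast(device_type=device_type):
            grid = F.affine_grid(theta, list(image.shape), align_corners=True)
            warped = F.grid_sample(image, grid, mode='bilinear',
                                   padding_mode='border', align_corners=True)
            loss = F.mse_loss(warped, image)
        loss.backward()
    except RuntimeError as err:
        return False, str(err)
    return True, ''


device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
# Small sizes keep the sketch cheap; use (256, 512, 1024) to mirror our tests
results = {size: amp_backward_ok(size, device_type) for size in (8, 16)}
for size, (ok, message) in results.items():
    print(size, 'ok' if ok else f'failed: {message}')
```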

### The Core Question

**Is there any way to make `F.grid_sample` backward pass work with AMP on 1024x1024 images?**

Specifically, we’re looking for answers to the following:

1. Is there a way to force the `grid_sample` backward pass to run in FP32 while keeping the rest of the model in mixed precision?

2. Are there any undocumented AMP settings that affect linear algebra operations?

3. Is this a known PyTorch limitation with a fix on the roadmap?

4. Are there alternative approaches we haven’t considered?
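To make the first point concrete, the kind of selective FP32 region we have in mind is sketched below. `grid_sample_fp32` is a hypothetical helper, and we have not confirmed that it resolves the error. It relies on the documented autocast behavior that a nested `enabled=False` region runs ops in the dtype of their (explicitly cast) inputs, and that backward ops run in the same dtype as the corresponding forward ops.

```python
import torch
import torch.nn.functional as F


def grid_sample_fp32(image, grid, **kwargs):
    """Hypothetical helper: run grid_sample (and hence its backward) in FP32."""
    # Locally disable autocast and cast both inputs to FP32 explicitly.
    with torch.autocast(device_type=image.device.type, enabled=False):
        return F.grid_sample(image.float(), grid.float(), **kwargs)


device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
image = torch.randn(1, 3, 8, 8, device=device_type)
raw_grid = torch.rand(1, 8, 8, 2, device=device_type, requires_grad=True)

with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
    # Stand-in for a grid produced in low precision by earlier layers
    low_precision_grid = (raw_grid * 2 - 1).to(torch.bfloat16)
    warped = grid_sample_fp32(low_precision_grid_image := image, low_precision_grid,
                              mode='bilinear', padding_mode='border',
                              align_corners=True)

warped.mean().backward()
print(warped.dtype, raw_grid.grad.dtype)
```

In the model, the call to `F.grid_sample` in `forward` would be swapped for the helper while everything else stays under autocast.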

### Minimal Reproducible Example

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler


class SimpleGridSampleModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Simple coordinate transformation
        self.transform = nn.Sequential(
            nn.Linear(2, 64),
            nn.ReLU(),
            nn.Linear(64, 2),
            nn.Sigmoid()
        )

    def forward(self, x):
        B, C, H, W = x.shape
        device = x.device

        # Create coordinate grid
        y, x_coords = torch.meshgrid(
            torch.linspace(0, 1, H, device=device),
            torch.linspace(0, 1, W, device=device),
            indexing='ij'
        )
        coord_grid = torch.stack([x_coords, y], dim=-1)  # [H, W, 2]
        coord_grid = coord_grid.unsqueeze(0).expand(B, -1, -1, -1)  # [B, H, W, 2]

        # Transform coordinates
        flat_coords = coord_grid.view(B, -1, 2)
        transformed = self.transform(flat_coords)
        transformed = transformed.view(B, H, W, 2)

        # Normalize to [-1, 1] for grid_sample
        grid = transformed * 2 - 1

        # grid_sample - THIS FAILS in backward pass with AMP at 1024x1024
        return F.grid_sample(x, grid, mode='bilinear', padding_mode='border', align_corners=True)


# Training setup
device = 'cuda'
model = SimpleGridSampleModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scaler = GradScaler(init_scale=64, growth_factor=1.001)  # Conservative settings

# 1024x1024 input - FAILS
data = torch.randn(1, 3, 1024, 1024).to(device)

optimizer.zero_grad()
with autocast():
    output = model(data)
    loss = F.mse_loss(output, data)

scaler.scale(loss).backward()  # ❌ torch.linalg.solve singular matrix error here
scaler.step(optimizer)
scaler.update()
```

### Additional Context

**Why we need AMP:**

- 1024x1024 images use ~11GB GPU memory per batch

- AMP reduces memory by ~30-40%, allowing larger batch sizes

- FP32 limits us to batch_size=1 with marginal memory headroom

- Training time is 2x slower without AMP
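Memory figures like these can be measured with a small helper along the lines of the sketch below. `peak_memory_mib` is a hypothetical name, and the lambda is a placeholder for one full training step.

```python
import torch


def peak_memory_mib(step_fn):
    """Return peak CUDA memory (MiB) allocated while running step_fn.

    Returns None when CUDA is unavailable (step_fn is still executed).
    """
    if not torch.cuda.is_available():
        step_fn()
        return None
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    step_fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20


# Placeholder workload; in practice step_fn would run forward, backward and optimizer step
peak = peak_memory_mib(lambda: torch.ones(1024, 1024).sum())
print(peak)
```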

**The mathematical issue:**

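Our working hypothesis (not verified against the PyTorch source) is that the `grid_sample` backward pass solves linear systems for the bilinear interpolation gradients. At 1024x1024, such a system would have H×W ≈ 1M unknowns, i.e. an (H×W) × (H×W) ≈ 1M × 1M matrix, and FP16 precision loss would turn near-singular matrices into fully singular ones.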
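A quick arithmetic check of the matrix size mentioned above, assuming a dense matrix with one row and column per pixel:

```python
H = W = 1024
n = H * W                    # one unknown per pixel: 1,048,576
dense_elements = n * n       # entries in a dense n x n matrix (2**40)

fp16_tib = dense_elements * 2 / 2**40   # 2 bytes per FP16 element
fp32_tib = dense_elements * 4 / 2**40   # 4 bytes per FP32 element

print(f"{n:,} x {n:,} = {dense_elements:,} elements")
print(f"FP16: {fp16_tib:.0f} TiB, FP32: {fp32_tib:.0f} TiB")
```

A dense matrix of that size would need about 2 TiB in FP16, far more than the 24GB on this card, so if a linear solve really is involved it presumably operates on much smaller systems. We have not been able to confirm which.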

### Questions for the Community

1. Is there a way to selectively disable AMP for specific operations (like `grid_sample`) while keeping AMP enabled for the rest of the model?

2. Are there any PyTorch internals we can monkey-patch to make `grid_sample_backward` use FP32?

3. Is this on the PyTorch roadmap to fix? If so, what’s the timeline?

4. Are there alternative implementations of differentiable image warping that work with AMP on large images?

Any help or insights would be greatly appreciated!

---

**Tags:** `automatic-mixed-precision`, `grid-sample`, `cuda`, `gradient-computation`, `numerical-stability`