torch.nn.functional.interpolate failing silently when OOM occurs during backward computation

Error: RuntimeError: CUDA error: invalid configuration argument
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

Here is a minimal training script to reproduce the issue (the behavior depends on GPU memory; I am running on an H100):

import torch
import torch.nn as nn
import torch.nn.functional as F


# Simplified model
class TinyModel(nn.Module):
    def __init__(self, scale_factor):
        super().__init__()
        self.conv = nn.Conv2d(32, 1, 3, padding=1)
        self.scale_factor = scale_factor

    def forward(self, x):
        H, W = x.size()[2:]  # NCHW input, so size()[2:] is (H, W)
        x = self.conv(x)
        # Upsample the conv output to scale_factor times the input resolution
        x = F.interpolate(x, size=(H * self.scale_factor, W * self.scale_factor),
                          mode='bilinear', align_corners=True)
        return x


# Create random sample
batch_size = 16
channels = 32
height = 32
width = 32
scale_factor = 80 * 5  # i.e. 400, so the interpolated output is 12800 x 12800
# Random input and target
input_tensor = torch.randn(batch_size, channels, height, width).cuda()
target = torch.randn(batch_size, 1, int(height * scale_factor), int(width * scale_factor)).cuda()  # Target size matches interpolated output

# Initialize model and optimizer
model = TinyModel(scale_factor=scale_factor).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.MSELoss().cuda()

# Mini training loop
num_epochs = 10
for epoch in range(num_epochs):
    # Forward pass
    output = model(input_tensor)
    loss = criterion(output, target)

    # Backward pass and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

print("Training finished!")

I don’t think this issue is related to an OOM, since I’m only using about half of the available device memory when I hit this error:

|=========================================+========================+======================|
|   0  NVIDIA H100 PCIe               On  |   00000000:C1:00.0 Off |                    0 |
| N/A   53C    P0             88W /  310W |   40592MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
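
For reference, here is a rough back-of-the-envelope estimate of the largest tensors in the script above (my own calculation, assuming contiguous float32 storage):

batch, channels = 16, 1
out_h = out_w = 32 * 80 * 5                  # 12800
elems = batch * channels * out_h * out_w     # 2,621,440,000 elements
print(f"{elems:,} elements -> {elems * 4 / 1024**3:.2f} GiB per float32 tensor")
# Note: this element count is above 2**31 - 1; I don't know whether the backward
# kernel uses 32-bit indexing, so this is only an observation, not a diagnosis.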

The stack trace points to:

#7  0x00007fffec47c12b in cudaLaunchKernel () from /usr/local/cuda/targets/x86_64-linux/lib/libcudart.so.12
#8  0x00007fff8a337084 in at::native::(anonymous namespace)::upsample_bilinear2d_backward_out_cuda_template(at::Tensor const&, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, bool, std::optional<double>, std::optional<double>) [clone .isra.0] () from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cuda.so
#9  0x00007fff8aa30953 in at::(anonymous namespace)::wrapper_CUDA_upsample_bilinear2d_backward(at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, bool, std::optional<double>, std::optional<double>) ()
   from /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_cuda.so
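
In case it helps isolate the failing op, here is a further reduced sketch that should exercise only the bilinear upsample backward kernel from the trace (an untested reduction of the script above, using the same sizes):

import torch
import torch.nn.functional as F

x = torch.randn(16, 1, 32, 32, device='cuda', requires_grad=True)
y = F.interpolate(x, size=(32 * 400, 32 * 400), mode='bilinear', align_corners=True)
# The backward of this call should dispatch to upsample_bilinear2d_backward on CUDA
y.sum().backward()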

Do you know which shapes are used in this kernel?

I’m not sure I understand the question. Can you elaborate? I am not an expert on CUDA kernels.