`CUDNN_BACKEND_OPERATION: cudnnFinalize Failed cudnn_status: CUDNN_STATUS_BAD_PARAM`

I'm trying to use `torch.nn.grad.conv1d_{input,weight}` as part of a custom backward pass. The operation runs when the dtype is `bfloat16` or `float32`, but fails with the aforementioned error when using `float16`.

Any ideas why?

Here is a minimal repro:

```python
import torch
from torch.nn.grad import conv1d_input, conv1d_weight

# Shapes
bs = 2       # batch size
seqlen = 32  # sequence length
d = 64       # model dim
g = 2        # number of filter groups
hl = 4       # filter length
dg = d // g  # dims per group

# NOTE: Changing this to float16 results in `RuntimeError: CUDNN_BACKEND_OPERATION: cudnnFinalize Failed cudnn_status: CUDNN_STATUS_BAD_PARAM`
dtype = torch.bfloat16

# Inputs
x = torch.randn(bs, seqlen, g, dg, device="cuda", dtype=dtype)
x2 = x.reshape(bs, seqlen, -1).permute(0, 2, 1)  # (bs, d, seqlen)
h = torch.randn(g, 1, hl, device="cuda", dtype=dtype)

h_grouped = h.repeat_interleave(dg, dim=0)  # (d, 1, hl)
assert h_grouped.shape == torch.Size([d, 1, hl])

padding = hl - 1
groups = d

# Depthwise causal conv
y = torch.nn.functional.conv1d(x2, h_grouped, groups=d, padding=padding)[..., :-padding]
assert y.shape == torch.Size([bs, d, seqlen])

dy = torch.randn_like(y)

# These ops fail if dtype is set to `float16`
dx = conv1d_input(x2.shape, h_grouped, dy, padding=padding, groups=groups)
dh_grouped = conv1d_weight(x2, h_grouped.shape, dy, padding=padding, groups=groups)
```
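For reference, these helpers should agree with plain autograd when `dy` is zero-padded back to the un-sliced conv output length (the gradient of the sliced-off tail is zero). Here is a minimal float64 CPU cross-check, with smaller hypothetical shapes than the repro above:

```python
import torch
import torch.nn.functional as F
from torch.nn.grad import conv1d_input, conv1d_weight

# Smaller, hypothetical shapes for a quick depthwise causal conv check
bs, seqlen, d, hl = 2, 32, 8, 4
padding = hl - 1
x = torch.randn(bs, d, seqlen, dtype=torch.float64, requires_grad=True)
w = torch.randn(d, 1, hl, dtype=torch.float64, requires_grad=True)

# Forward: depthwise causal conv, sliced to seqlen
y = F.conv1d(x, w, groups=d, padding=padding)[..., :-padding]
dy = torch.randn_like(y)

# Reference gradients via autograd
dx_auto, dw_auto = torch.autograd.grad(y, [x, w], dy)

# Explicit gradients: zero-pad dy back to the full conv output length first
dy_full = F.pad(dy, (0, padding))
dx = conv1d_input(x.shape, w, dy_full, padding=padding, groups=d)
dw = conv1d_weight(x, w.shape, dy_full, padding=padding, groups=d)
```

Both pairs should match to float64 precision.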

Could you post the device you are using as well as the PyTorch version?

Torch version: 2.5.0.dev20240814+cu121
Device: NVIDIA H100 80GB HBM3

Thanks! I’m able to reproduce the issue and have forwarded it.
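In the meantime, a possible workaround (an untested sketch, not an official fix) is to run the two backward helpers in float32 and cast the results back to the working dtype, since the failure only appears with float16:

```python
import torch
from torch.nn.grad import conv1d_input, conv1d_weight

def conv1d_grads_fp32(x, w, dy, padding, groups):
    """Hypothetical helper: upcast to float32 to sidestep the cuDNN
    float16 failure, then cast the gradients back to the input dtypes."""
    dx = conv1d_input(x.shape, w.float(), dy.float(), padding=padding, groups=groups)
    dw = conv1d_weight(x.float(), w.shape, dy.float(), padding=padding, groups=groups)
    return dx.to(x.dtype), dw.to(w.dtype)

# Example with small float16 tensors (CPU here; same shapes work on CUDA)
bs, seqlen, d, hl = 2, 32, 8, 4
padding = hl - 1
x = torch.randn(bs, d, seqlen, dtype=torch.float16)
w = torch.randn(d, 1, hl, dtype=torch.float16)
dy = torch.randn(bs, d, seqlen + padding, dtype=torch.float16)  # full conv output length
dx, dw = conv1d_grads_fp32(x, w, dy, padding=padding, groups=d)
```

This trades some memory and bandwidth for the upcast, but keeps the surrounding model in float16 until the cuDNN issue is resolved.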