Trying to use `torch.nn.grad.conv1d_{input,weight}` as part of a custom backward pass. The operation runs when the dtype is bfloat16 or float32, but fails with `RuntimeError: CUDNN_BACKEND_OPERATION: cudnnFinalize Failed cudnn_status: CUDNN_STATUS_BAD_PARAM` when using float16. Any ideas why?
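For context, the pattern is roughly this (a generic sketch, not my actual code):

import torch
from torch.nn.grad import conv1d_input, conv1d_weight

class Conv1dFn(torch.autograd.Function):
    # Plain conv1d with a hand-written backward (illustrative only).
    @staticmethod
    def forward(ctx, x, w, padding, groups):
        ctx.save_for_backward(x, w)
        ctx.padding, ctx.groups = padding, groups
        return torch.nn.functional.conv1d(x, w, padding=padding, groups=groups)

    @staticmethod
    def backward(ctx, dy):
        x, w = ctx.saved_tensors
        # Grads w.r.t. the input and the weight; None for the two non-tensor args.
        dx = conv1d_input(x.shape, w, dy, padding=ctx.padding, groups=ctx.groups)
        dw = conv1d_weight(x, w.shape, dy, padding=ctx.padding, groups=ctx.groups)
        return dx, dw, None, None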
Here is a minimal repro:
import torch
from torch.nn.grad import conv1d_input, conv1d_weight
# Shapes
bs = 2        # batch size
seqlen = 32   # sequence length
d = 64        # total channels
g = 2         # number of distinct filters
hl = 4        # filter length
dg = d // g   # channels per filter
# NOTE: Changing this to float16 results in `RuntimeError: CUDNN_BACKEND_OPERATION: cudnnFinalize Failed cudnn_status: CUDNN_STATUS_BAD_PARAM`
dtype = torch.bfloat16
# Inputs
x = torch.randn(bs, seqlen, g, dg, device="cuda", dtype=dtype)
x2 = x.reshape(bs, seqlen, -1).permute(0, 2, 1) # bs, d, seqlen
h = torch.randn(g, 1, hl, device="cuda", dtype=dtype)
h_grouped = h.repeat_interleave(dg, dim=0)  # repeat each filter dg times -> (d, 1, hl) depthwise weight
assert h_grouped.shape == torch.Size([d, 1, hl])
padding = hl - 1
groups = d
# Depthwise causal conv: conv1d pads both ends, so trim the trailing
# `padding` outputs to keep only the causal part.
y = torch.nn.functional.conv1d(x2, h_grouped, groups=d, padding=padding)[..., :-padding]
assert y.shape == torch.Size([bs, d, seqlen])
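# Extra sanity check (not needed for the repro): the pad+trim trick really is
# causal -- perturbing the last timestep leaves all earlier outputs unchanged.
x2_pert = x2.clone()
x2_pert[..., -1] += 1.0
y_pert = torch.nn.functional.conv1d(x2_pert, h_grouped, groups=d, padding=padding)[..., :-padding]
torch.testing.assert_close(y[..., :-1], y_pert[..., :-1])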
dy = torch.randn_like(y)
# These ops will fail if dtype is set to `float16`
dx = conv1d_input(x2.shape, h_grouped, dy, padding=padding, groups=groups)
dh_grouped = conv1d_weight(x2, h_grouped.shape, dy, padding=padding, groups=groups)
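For what it's worth, upcasting just these two calls to float32 and casting the results back does run (a workaround sketch continuing the repro; `dx_f16` / `dh_f16` are names I made up):

# Workaround sketch: run the grad kernels in float32, then cast back.
dx_f16 = conv1d_input(
    x2.shape, h_grouped.float(), dy.float(), padding=padding, groups=groups
).to(dtype)
dh_f16 = conv1d_weight(
    x2.float(), h_grouped.shape, dy.float(), padding=padding, groups=groups
).to(dtype)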