torch.where is breaking model training

I’m trying to divide one tensor by another in PyTorch, but only where the magnitude of the denominator is at least a certain threshold.

The following implementation works, but it won’t compile with TorchDynamo:

wsq_ola = wsq_ola.to(wav).expand_as(wav).clone()
min_mask = wsq_ola.abs() < eps
wav[~min_mask] = wav[~min_mask] / wsq_ola[~min_mask]

I tried to implement the same thing with torch.where instead as follows:

wsq_ola = wsq_ola.to(wav).expand_as(wav).clone()
min_mask = wsq_ola.abs() < eps
wav = torch.where(min_mask, wav, wav / wsq_ola)

Unfortunately, once I make this change, the model no longer converges. Is there some issue with my use of torch.where here? For context, this is part of an STFT layer with no trainable weights.

I don’t see any issues with your second snippet, and I get the same gradients in eager mode and with torch.compile:

import torch

wsq_ola = torch.randn(10, 10, requires_grad=False)
wav = torch.randn(10, 10, requires_grad=True)
eps = 0.1

# Eager mode
min_mask = wsq_ola.abs() < eps
out = torch.where(min_mask, wav, wav / wsq_ola)
out.mean().backward()
print(wav.grad.abs().sum())
# tensor(1.9265)

def fun(wav, wsq_ola, eps):
    # Compute the mask inside the function so it does not rely on a global
    min_mask = wsq_ola.abs() < eps
    out = torch.where(min_mask, wav, wav / wsq_ola)
    return out

# Compiled: reset the gradient, then repeat the same computation
wav.grad = None
fun = torch.compile(fun)
out = fun(wav, wsq_ola, eps)
out.mean().backward()
print(wav.grad.abs().sum())
# tensor(1.9265)

Do you see a difference using my code?

I don’t see any difference when I run your code or in my unit tests. However, when I start an actual training run, suddenly nothing works. Is there an alternate way to express this division? I only need it to get torchdynamo working.
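
One thing that may be worth ruling out (an assumption, not something confirmed in this thread): if wsq_ola contains exact zeros during real training, torch.where can still produce NaN gradients. The masked-out branch wav / wsq_ola is evaluated everywhere, and its backward computes 0 / 0 at those positions. Random test inputs from torch.randn almost never contain an exact zero, which could explain why unit tests pass. A minimal sketch of one possible workaround is to replace the small denominators with 1 before dividing. The helper names naive_divide and safe_divide below are hypothetical and chosen only for illustration:

```python
import torch

def naive_divide(wav, wsq_ola, eps):
    # Same expression as in the question: both branches are evaluated everywhere
    mask = wsq_ola.abs() < eps
    return torch.where(mask, wav, wav / wsq_ola)

def safe_divide(wav, wsq_ola, eps):
    # Hypothetical workaround: swap small denominators for 1 so the division
    # (and its backward pass) never sees a zero; masked positions keep wav as-is
    mask = wsq_ola.abs() < eps
    safe_den = torch.where(mask, torch.ones_like(wsq_ola), wsq_ola)
    return wav / safe_den

# Toy denominator with an exact zero, chosen only for illustration
wsq_ola = torch.tensor([0.0, 0.5, 2.0])
eps = 0.1

wav = torch.tensor([1.0, 1.0, 1.0], requires_grad=True)
naive_divide(wav, wsq_ola, eps).sum().backward()
naive_grad = wav.grad.clone()

wav.grad = None
safe_divide(wav, wsq_ola, eps).sum().backward()
safe_grad = wav.grad.clone()

print(naive_grad)  # first entry is nan
print(safe_grad)   # all entries are finite
```

Both functions return the same forward values, and safe_divide uses only out-of-place operations, so it should not hit the boolean-index assignment that the original version relies on. Checking torch.isfinite(wav.grad).all() after a backward pass in the real training run would show whether this is actually the cause.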