How can I apply dropout to torch.complex data types?

Can I use nn.Dropout() with the cfloat dtype?
I always get the error “fused_dropout” not implemented for ‘ComplexFloat’.

It seems that it is not implemented yet.

I’d probably work around it by applying dropout to a real-valued tensor of ones and multiplying the input by the resulting mask:

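import torch

def dropout_complex(x, p=0.5, training=True):
    # Work around dropout not being implemented for complex dtypes:
    # dropout on a real tensor of ones yields a mask of 0 or 1 / (1 - p),
    # so the real and imaginary parts of each entry are dropped together.
    if x.is_complex():
        mask = torch.nn.functional.dropout(torch.ones_like(x.real), p, training)
        return x * mask
    else:
        return torch.nn.functional.dropout(x, p, training)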

a = torch.randn(5, 5, dtype=torch.cfloat)
print(dropout_complex(a))
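
If you want something that behaves like nn.Dropout() and switches off in eval mode, the same trick could be wrapped in a module. This is only a minimal sketch: the class name ComplexDropout is made up for illustration and is not part of PyTorch, and it assumes you want the real and imaginary parts of an entry to be dropped together.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ComplexDropout(nn.Module):
    """Hypothetical helper (not a PyTorch API): dropout that also accepts complex input."""

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, x):
        if not x.is_complex():
            return F.dropout(x, self.p, self.training)
        # Real-valued mask: entries are 0 or 1 / (1 - p) in training mode,
        # and all ones in eval mode, so eval leaves x unchanged.
        mask = F.dropout(torch.ones_like(x.real), self.p, self.training)
        return x * mask


drop = ComplexDropout(p=0.5)
a = torch.randn(5, 5, dtype=torch.cfloat)

drop.train()
out = drop(a)   # each entry is either zeroed or scaled by 2

drop.eval()
same = drop(a)  # eval mode: the input passes through unchanged
```

Multiplying by a real mask keeps the complex dtype, so the rest of the model does not need to change.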

Best regards

Thomas

How did you do that? You are a genius :slight_smile: