Can I use nn.Dropout() with the cfloat dtype?
I always get the error “fused_dropout” not implemented for ‘ComplexFloat’.
It seems that it is not implemented yet.
I’d probably work around it like this:

import torch

def dropout_complex(x):
    # Work around unimplemented dropout for complex tensors: apply dropout to a
    # real-valued tensor of ones and use the result as a mask.
    if x.is_complex():
        mask = torch.nn.functional.dropout(torch.ones_like(x.real))
        return x * mask
    else:
        return torch.nn.functional.dropout(x)

a = torch.randn(5, 5, dtype=torch.cfloat)
print(dropout_complex(a))
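
If you want this as a drop-in layer, the same trick can be wrapped in a module. This is only a rough sketch: ComplexDropout is a made-up name (it is not part of torch.nn), and I am assuming you want the real and imaginary parts of an element dropped together and dropout switched off in eval mode. The mask comes from running the regular real-valued dropout on a tensor of ones, so it contains either 0 or 1 / (1 - p), and multiplying it into the complex tensor zeroes or rescales whole complex entries.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ComplexDropout(nn.Module):
    """Hypothetical wrapper around the workaround above; not a torch.nn class."""

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, x):
        if x.is_complex():
            # Real-valued mask of 0 or 1 / (1 - p); it is all ones in eval mode.
            mask = F.dropout(torch.ones_like(x.real), p=self.p, training=self.training)
            # Real and imaginary parts share the mask, so they are dropped together.
            return x * mask
        return F.dropout(x, p=self.p, training=self.training)


drop = ComplexDropout(p=0.5)
a = torch.randn(5, 5, dtype=torch.cfloat)

drop.train()
out_train = drop(a)  # entries are either zeroed or scaled by 2

drop.eval()
out_eval = drop(a)   # input passes through unchanged
```

Passing training=self.training means the layer follows model.train() and model.eval() like nn.Dropout does, whereas the plain function above always applies dropout.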
Best regards
Thomas
How did you do that? You are a genius.