Depending on your CPU, you might be able to use torch.set_flush_denormal(True)
to avoid the slower code path for denormal values. That only helps if denormals are actually causing the slowdown, which could be the case based on your description.
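
Something like this minimal sketch is what I have in mind. The variable names `flush_enabled` and `tiny` are just placeholders I made up. The call returns a bool telling you whether the mode was actually enabled on your CPU, so it is worth checking before you re-run your timing:

```python
import torch

# Ask PyTorch to flush denormal (subnormal) floats to zero on the CPU.
# The call returns True only if the CPU supports this mode and it was enabled.
flush_enabled = torch.set_flush_denormal(True)

if flush_enabled:
    print("Denormals are now flushed to zero on the CPU.")
else:
    print("Flush-to-zero is not supported on this CPU; nothing changed.")

# 1e-323 is a denormal float64. With the mode enabled, the PyTorch docs
# show this tensor holding 0 instead of the tiny value.
tiny = torch.tensor([1e-323], dtype=torch.float64)
print(tiny)
```

If the call returns False, or your timings do not change after enabling it, denormals are probably not the cause of the slowdown.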