The problem is most likely that you're hitting denormal numbers (floating-point numbers too close to zero are very expensive to work with).
Does adding torch.set_flush_denormal(True) before the op solve the issue?
This will reduce the precision of operations on very, very small numbers, but it should remove the slowdown.
Problem solved. For the sake of people who might see this post, let me summarize the problem and solution:
The problem is that floating-point operations on the CPU can become very slow when numbers are "denormal" (also called "subnormal"). This happens when values are extremely small — for float32, smaller in magnitude than about 1.2e-38 — which is very unusual.
The workaround is to force the CPU to treat these numbers as zeros. At the beginning of your code, add torch.set_flush_denormal(True) (this may not work on some older Intel CPUs, though).
An alternative workaround is to "manually" remove denormals from your weights:
Has @pgmmpk (or anyone else) noticed any common causes for training runs ending up with denormal weights? I have been running some quick proof-of-concept experiments with an open-source framework on GitHub, and its training code produces multiple denormal weights. Not having written the code myself, it would be helpful to have some pointers on where to start looking.
Note that setting torch.set_flush_denormal(True) may cause an accuracy regression. I suppose that's because there are too many denormal numbers in the Conv layers' weights.
So the final solution has to be reducing or avoiding denormal weights during training. Does anyone have best practices for this?
One more observation: after updating PyTorch from 1.13 to 2.0, the number of denormal weights produced by the same recipe increased. Any clue?