Conv2D is very slow on trained weights (vs random weights)

Super weird.

Doing torch.nn.functional.conv2d(x, weight) sometimes takes 30x longer (on CPU). There are no NaNs and all the data seems reasonable.

Anyone seen something like this?

See my gist for the minimal working example: https://gist.github.com/pgmmpk/b64901e3bd77ed58c00be83bd170d982

I see this in torch 0.4.1 and 1.0.1.post2 (did not try others). Installed from pip.

-Mike


Hi,

The problem is most likely that you’re hitting denormal numbers (floating-point numbers too close to zero are very expensive to work with on CPU).
Does adding torch.set_flush_denormal(True) before the op solve the issue?
This will reduce the precision of operations on very small numbers, but it should remove the slowdown.

EDIT: cross-post from GitHub here.
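
For reference, a rough benchmark along these lines (the shapes and the 1e-40 filler value are arbitrary, not taken from the gist) shows both the slowdown and the effect of flushing:

import time
import torch
import torch.nn.functional as F

x = torch.randn(1, 16, 128, 128)
normal_w = torch.randn(32, 16, 3, 3)
denormal_w = torch.full((32, 16, 3, 3), 1e-40)   # 1e-40 is subnormal in float32

def bench(weight, iters=20):
    start = time.perf_counter()
    for _ in range(iters):
        F.conv2d(x, weight)
    return time.perf_counter() - start

print("normal weights:   ", bench(normal_w))
print("denormal weights: ", bench(denormal_w))   # typically far slower on CPU

torch.set_flush_denormal(True)                   # returns False if the CPU does not support it
print("denormal, flushed:", bench(denormal_w))   # back to normal speed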

Problem solved. For the sake of people who might see this post, let me summarize the problem and solution:

The problem is that floating-point operations on CPU can become very slow if the numbers involved are “denormal” (also called “subnormal”). These are extremely small values (smaller than about 1.2e-38 in float32), which you would not normally see in trained weights.

The workaround is to force the CPU to treat these numbers as zeros. At the beginning of your code, add torch.set_flush_denormal(True) (this may not work on some older Intel CPUs; it returns False in that case).

An alternative workaround is to “manually” flush denormals from your weights:

weight = ...
tiny = torch.finfo(weight.dtype).tiny      # smallest normal value for this dtype
mask = (weight.abs() >= tiny).float()      # 0 for denormals (and exact zeros), 1 otherwise
weight = weight * mask
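
If the denormals are spread across many layers, the same idea can be applied to the whole model. A rough sketch (flush_denormal_weights is just an illustrative helper, not a PyTorch API):

import torch

@torch.no_grad()
def flush_denormal_weights(model: torch.nn.Module) -> int:
    """Zero out denormal values in every parameter; return how many were flushed."""
    flushed = 0
    for p in model.parameters():
        tiny = torch.finfo(p.dtype).tiny        # smallest normal value for this dtype
        denormal = (p.abs() < tiny) & (p != 0)  # exclude exact zeros
        flushed += int(denormal.sum())
        p[denormal] = 0.0
    return flushed

# usage: print(flush_denormal_weights(model), "denormal weights zeroed")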

The better solution is to find out why your training ended up with such weird weights, and fix that!

Thanks to @albanD and others for the awesome support.


Have @pgmmpk or others noticed any common reasons why training ends up with denormal weights? I have been running some quick proof-of-concept experiments with an open-source framework on GitHub, and its training code produces many denormal weights. Since I did not write the code myself, it would be helpful to have some pointers on where to start looking.
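
For reference, a small helper like the one below (report_denormals is just an illustrative name) at least shows which layers the denormals end up in:

import torch

def report_denormals(model: torch.nn.Module):
    """Print the fraction of denormal (subnormal, non-zero) values per parameter."""
    for name, p in model.named_parameters():
        tiny = torch.finfo(p.dtype).tiny
        frac = ((p.abs() < tiny) & (p != 0)).float().mean().item()
        if frac > 0:
            print(f"{name}: {frac:.2%} denormal")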

Note that setting torch.set_flush_denormal(True) may cause an accuracy regression. I suppose that happens when there are too many denormal numbers in the conv’s weights.
So the real solution has to be to reduce or avoid denormal weights during training. Does anyone have a best practice for this?
One more observation: after updating PyTorch from 1.13 to 2.0, the same training recipe produced more denormal weights. Any clue?