I have a network that contains the following layer somewhere inside:
self.avg_pool = nn.AdaptiveAvgPool2d((2, 2))
For some inputs, however, the network crashes and the result contains NaN. I have probably tracked the problem down to this pooling layer:
print(x.shape)
print("IsInf x", torch.isinf(x).any())
print("IsNan x", torch.isnan(x).any())
x0 = self.avg_pool(x)
print(x0.shape)
print("IsInf after avg pool", torch.isinf(x0).any())
print("IsNan after avg pool", torch.isnan(x0).any())
torch.Size([20, 32, 256, 256])
IsInf x tensor(False, device='cuda:0')
IsNan x tensor(False, device='cuda:0')
torch.Size([20, 32, 2, 2])
IsInf after avg pool tensor(True, device='cuda:0')
IsNan after avg pool tensor(False, device='cuda:0')
What can cause this behaviour?
Since you have already isolated it to a single tensor, store it and analyze its values to see how the Inf output is calculated.
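A minimal sketch of storing the suspicious activation for offline analysis; the file name and the random stand-in tensor are illustrative, not from the actual network:

```python
import torch

# Stand-in for one 256x256 channel of the real activation.
x = torch.randn(256, 256)

# Dump it to disk so it can be reloaded and inspected offline.
torch.save(x.detach().cpu(), "pool_input.pt")

loaded = torch.load("pool_input.pt")
print(loaded.min().item(), loaded.max().item())
print("IsInf", torch.isinf(loaded).any(), "IsNan", torch.isnan(loaded).any())
```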
I have isolated a single channel from the batch; it is 256x256. Its stats: min_val = -63.97, max_val = 13.19, and most of the values are fairly small.
Why, or how, can AdaptiveAvgPool2d end up with Inf? I don't know exactly how it is calculated, but the average of such “small” numbers should not be Inf. If some division by 0 occurred, I would expect a NaN result, not Inf.

A division by 0 should create Infs:
torch.randn(2, 2) / torch.tensor(0.)
# tensor([[-inf, -inf],
#         [inf, inf]])
but I also don’t know how this result could be calculated, assuming you are indeed not dividing by 0.
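For reference, the NaN-vs-Inf distinction follows IEEE-754 and can be checked directly: only the indeterminate form 0/0 yields NaN, while a nonzero numerator over 0 yields signed Inf.

```python
import torch

# Nonzero / zero -> signed infinity.
print(torch.tensor(1.0) / torch.tensor(0.0))   # tensor(inf)
print(torch.tensor(-1.0) / torch.tensor(0.0))  # tensor(-inf)

# Zero / zero is indeterminate -> NaN.
print(torch.tensor(0.0) / torch.tensor(0.0))   # tensor(nan)
```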
Are you able to store the tensor and manually reproduce the output by running the pooling layer alone?
I totally forgot: I have autocast enabled, so that may be causing the problem. Since the input tensor is large (256x256) and the output is only 2x2, the accumulation may break float16 precision: float16 can only represent values up to 65504, so summing 65,536 elements of this magnitude can overflow to Inf before the division. As for the division by 0, you are right; I confused it with the sqrt of negative numbers, which yields NaN.
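The overflow is easy to confirm outside the network; the concrete numbers below are ours, chosen to match the 256x256 channel with values of roughly this magnitude:

```python
import torch

# float16 tops out at 65504; larger magnitudes overflow to inf.
print(torch.tensor(70000.0).to(torch.float16))  # tensor(inf, dtype=torch.float16)

# A 256x256 channel of values around 40 sums to ~2.6 million,
# so a float16 accumulator overflows long before the final division.
x = torch.full((256, 256), 40.0)
print(x.sum())                    # float32 sum: tensor(2621440.), finite
print(x.sum().to(torch.float16))  # cast to float16: tensor(inf, dtype=torch.float16)
```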
Ah OK, in this case I assume the input is already in float16 coming from another layer, and the pooling uses this dtype as its internal compute type. We might need to check whether the internal compute type should be a wider floating point type (float32 in this case).
In most cases, for smaller sizes, float16 would be fine, but this one is rather large. Is there some way to force a single module's compute precision to float32 during mixed precision in general?
Not directly, as you would need to change the backend. @eqy is proposing a fix using a wider acctype during the accumulation. For now you could manually cast the inputs back to torch.float32 to avoid these issues.
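One way to sketch that workaround (the wrapper name is ours, not a PyTorch API): disable autocast around the single module and upcast its input, so the pooling accumulation runs in float32 even inside an autocast region.

```python
import torch
import torch.nn as nn

class ForceFloat32(nn.Module):
    """Run the wrapped module in float32, outside any autocast region."""
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        # Turn autocast off locally and upcast the input to float32.
        with torch.autocast(device_type=x.device.type, enabled=False):
            return self.module(x.float())

pool = ForceFloat32(nn.AdaptiveAvgPool2d((2, 2)))
x = torch.randn(1, 32, 256, 256)
out = pool(x)
print(out.shape, out.dtype)  # torch.Size([1, 32, 2, 2]) torch.float32
```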