I have a network that contains an AdaptiveAvgPool2d block somewhere inside.
For some inputs, however, the network crashes and the result contains NaN values. I have tracked the problem down, probably to the AdaptiveAvgPool2d:
self.avg_pool = nn.AdaptiveAvgPool2d((2, 2))  # defined in __init__
print(x.shape)
print("IsInf x", torch.isinf(x).any())
print("IsNan x", torch.isnan(x).any())
x0 = self.avg_pool(x)
print(x0.shape)
print("IsInf after avg pool", torch.isinf(x0).any())
print("IsNan after avg pool", torch.isnan(x0).any())
Which outputs:
torch.Size([20, 32, 256, 256])
IsInf x tensor(False, device='cuda:0')
IsNan x tensor(False, device='cuda:0')
torch.Size([20, 32, 2, 2])
IsInf after avg pool tensor(True, device='cuda:0')
IsNan after avg pool tensor(False, device='cuda:0')
What can cause this behaviour?
Since you have already isolated it to a single tensor, store it and analyze its values to see how the Inf output is calculated.
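E.g. something like this (a rough sketch; 'bad_input.pt' is just a placeholder file name):
import torch

# inside the forward pass: dump the offending activation to disk
# torch.save(x.detach().cpu(), 'bad_input.pt')

# later, in a separate script: load it back and look at its statistics
x = torch.load('bad_input.pt')
print(x.dtype, x.shape)
print(x.min(), x.max(), x.abs().max())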
I have isolated a single channel from the batch; it is 256x256.
Its stats are:
min_val = -63.97, max_val = 13.19
Most of the values are something like Xe-01 in magnitude.
Why, or how, can AdaptiveAvgPool2d end up with Inf? I don't know exactly how it is calculated, but the average of such “small” numbers should not be Inf. If a division by 0 occurred, I would expect a NaN result, not Inf.
Dividing by 0 should create Infs:
torch.randn(2, 2) / torch.tensor(0.)
# tensor([[-inf, -inf],
# [inf, inf]])
but I also don’t know how this result should be calculated assuming you are indeed not dividing by 0.
Are you able to store the tensor and manually reproduce the output by running the pooling layer alone?
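I.e. something along these lines (a sketch, assuming the tensor was stored as 'bad_input.pt' as suggested above and a CUDA device is used, as in your run):
import torch

x = torch.load('bad_input.pt').cuda()
pool = torch.nn.AdaptiveAvgPool2d((2, 2))
out = pool(x)
print(torch.isinf(out).any())  # check whether the Inf reproduces outside the full model

# reference: the same 2x2 average computed in float64 via a plain avg_pool2d
# (valid here since 256 is divisible by 2, so each output averages a 128x128 window)
ref = torch.nn.functional.avg_pool2d(x.double(), kernel_size=128)
print(ref.min(), ref.max())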
I totally forgot that I have autocast enabled, so that may be causing the problems.
Since the input tensor is large (256x256) and the output is only 2x2, the accumulation may break float16 precision.
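As a back-of-the-envelope check (just an illustration of the mechanism, not the actual kernel code): each of the 2x2 outputs averages a 128x128 window, i.e. 16384 values, and float16 cannot represent anything above ~65504, so a running float16 sum could overflow to Inf before the final division even when the individual values are not huge:
import torch

print(torch.finfo(torch.float16).max)           # 65504.0 - largest finite float16 value
print(torch.tensor(70000.0).to(torch.float16))  # tensor(inf, dtype=torch.float16) - overflow
print(torch.tensor(float('inf')) / 16384)       # tensor(inf) - the final division cannot undo it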
As for the division by 0, you are right. I had confused it with the sqrt of negative numbers.
Ah OK, in this case I assume the input is already in float16, coming from another layer, and the pooling uses this dtype as its internal compute type.
We might need to check if the internal compute type should use a larger floating point precision (float32 in this case).
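As far as I know, pooling ops are not on the autocast cast lists, so they simply run in the dtype of their input; a small check of that assumption (assuming a CUDA device):
import torch

with torch.autocast(device_type='cuda', dtype=torch.float16):
    x = torch.randn(1, 3, 256, 256, device='cuda')
    w = torch.randn(32, 3, 3, 3, device='cuda')
    y = torch.nn.functional.conv2d(x, w)
    print(y.dtype)    # torch.float16 - the conv runs in half under autocast
    out = torch.nn.AdaptiveAvgPool2d((2, 2))(y)
    print(out.dtype)  # torch.float16 - the pooling keeps the input dtype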
In most cases, for smaller sizes, float16 would be fine. But this is rather large … is there some general way to force a single module's compute precision to be float32 during mixed precision?
Not directly, as you would need to change the backend. @eqy is proposing a fix using a wider acctype during the accumulation. For now you could cast the inputs manually back to torch.float32 to avoid these issues.
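Something along these lines should work as a manual workaround (a sketch, not an official API; the wrapper name is made up and it assumes CUDA autocast as in your setup):
import torch
import torch.nn as nn

class FP32AdaptiveAvgPool2d(nn.Module):
    # hypothetical wrapper: run the pooling in float32 even inside an autocast region
    def __init__(self, output_size):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(output_size)

    def forward(self, x):
        with torch.autocast(device_type='cuda', enabled=False):
            out = self.pool(x.float())  # accumulate in float32
        return out.to(x.dtype)          # cast back so the rest of the autocast region is unchanged

You could then replace self.avg_pool = nn.AdaptiveAvgPool2d((2, 2)) with self.avg_pool = FP32AdaptiveAvgPool2d((2, 2)) in your module.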