AdaptiveAvgPool2d causes some data to contain Inf

I have a network that contains an AdaptiveAvgPool2d block somewhere inside.
For some inputs, however, the network crashes and the result contains NaN. I have tracked the problem down, most likely to the AdaptiveAvgPool2d.

self.avg_pool = nn.AdaptiveAvgPool2d((2, 2))

print(x.shape)
print("IsInf x", torch.isinf(x).any())
print("IsNan x", torch.isnan(x).any())
x0 = self.avg_pool(x)
print(x0.shape)
print("IsInf after avg pool", torch.isinf(x0).any())
print("IsNan after avg pool", torch.isnan(x0).any())

Which outputs:

torch.Size([20, 32, 256, 256])
IsInf x tensor(False, device='cuda:0')
IsNan x tensor(False, device='cuda:0')
torch.Size([20, 32, 2, 2])
IsInf after avg pool tensor(True, device='cuda:0')
IsNan after avg pool tensor(False, device='cuda:0')

What can cause this behaviour?

Since you have already isolated it to a single tensor, store it and analyze its values to see how the Inf output is calculated.
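For example, something along these lines (the file name is arbitrary) lets you dump the activation and then try to reproduce the issue in isolation:

import torch
import torch.nn.functional as F

# Dump the offending activation right before the pooling layer
# (torch.save keeps the tensor's dtype and device):
torch.save(x.detach(), "pool_input.pt")

# Later: reload it, check its stats, and run the pooling op alone.
x = torch.load("pool_input.pt")
print(x.dtype, x.shape)
print(x.min(), x.max(), x.abs().mean())
print(torch.isinf(F.adaptive_avg_pool2d(x, (2, 2))).any())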

I have isolated a single channel from the batch, which is 256x256.
Based on its stats:

min_val = -63.97, max_val = 13.19

Most of the values are on the order of Xe-01.

Why, or how, can AdaptiveAvgPool2d end up with Inf? I don't know exactly how it is calculated, but the average of such “small” numbers should not be Inf. If some division by 0 occurred, I would expect a NaN result, not Inf.

Dividing by 0 should create Infs:

torch.randn(2, 2) / torch.tensor(0.)
# tensor([[-inf, -inf],
#         [inf, inf]])

But I also don't see how this result could be produced if you are indeed not dividing by 0.
Are you able to store the tensor and manually reproduce the output by running the pooling layer alone?
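(As a side note, it is 0/0 that yields NaN, while a nonzero value divided by 0 gives ±Inf:)

torch.tensor(0.) / torch.tensor(0.)
# tensor(nan)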

I totally forgot that I have autocast enabled, so that may be causing the problem.
Since the input tensor is large (256x256) and the output is only 2x2, the accumulation may break float16 precision.
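To illustrate what I mean (just a toy sketch of the overflow mechanism, not the actual pooling kernel): float16 tops out at 65504, so a running float16 sum over 256*256 = 65536 elements can exceed that before the final division, even though every individual value is small.

import torch

print(torch.finfo(torch.float16).max)   # 65504.0

vals = torch.ones(256 * 256, dtype=torch.float16)
s = torch.zeros((), dtype=torch.float16)
for chunk in vals.chunk(64):            # keep the running sum in float16
    s = s + chunk.sum()                 # 1024, 2048, ... eventually exceeds 65504
print(s)                                # tensor(inf, dtype=torch.float16)
print(vals.float().mean())              # tensor(1.) when accumulated in float32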

As for the division by 0, you are right. I had confused it with the sqrt of negative numbers.

Ah OK, in this case I assume the input is already in float16, coming from another layer, and the pooling uses this dtype as its internal compute type.
We might need to check whether the internal compute type should use a larger floating point precision (float32 in this case).
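E.g. you can quickly check that the pooling keeps the float16 dtype rather than upcasting the output on its own (on a CUDA device, as in your setup):

import torch
import torch.nn as nn

pool = nn.AdaptiveAvgPool2d((2, 2))
x16 = torch.randn(1, 1, 8, 8, device="cuda", dtype=torch.float16)
print(pool(x16).dtype)   # torch.float16 -- the output stays in the input dtype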

In most cases, for smaller sizes, float16 would be fine. But this is rather large … is there some general way to force a single module's compute precision to float32 during mixed precision?

Not directly, as you would need to change the backend. @eqy is proposing a fix using a wider acctype during the accumulation. For now you could cast the inputs manually back to torch.float32 to avoid these issues.
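For example, a small wrapper along these lines (a hypothetical helper, not a built-in module) keeps the rest of the network in mixed precision but runs the pooling itself in float32:

import torch
import torch.nn as nn

class Float32AvgPool(nn.Module):
    # Hypothetical wrapper: upcast to float32 for the pooling, then cast back,
    # so the accumulation cannot overflow the float16 range.
    def __init__(self, output_size):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(output_size)

    def forward(self, x):
        return self.pool(x.float()).to(x.dtype)

pool = Float32AvgPool((2, 2))
x = torch.randn(20, 32, 256, 256, device="cuda", dtype=torch.float16)
print(pool(x).dtype)   # torch.float16, but the mean itself was computed in float32

The explicit .float() cast should also be respected inside an autocast region.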