Got NaN in forward pass with `torch.amp`

Hi, guys,
I'm getting NaN values when using torch.amp, like this:

```python
...
# forward pass
with torch.cuda.amp.autocast(self.use_amp):
    x = self.backbone(img)
    x = self.neck(x)

    if self.qa.check_nan:
        assert not x.isnan().any()
        assert not self.dict_feats["layer2"].isnan().any()
        for p in self.decoder2.parameters():
            assert not p.isnan().any()

    if isinstance(self.decoder2, FuseDecoder):
        encode_data = self.decoder2(x, self.dict_feats["layer2"])
    else:
        raise NotImplementedError

if self.qa.check_nan:
    assert not encode_data.isnan().any()
...
```

and the error is:

```
in train_epochs
    output = self.model(batch_img)
  File "/home/user/software/python/anaconda/anaconda3/envs/conda-general/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/…net.py", line 331, in forward
    assert not encode_data.isnan().any()
AssertionError
```

What can I do to solve this issue?

By the way, I'm not sure whether this debug code is enough to pinpoint the problem, so any debugging suggestions are also appreciated.

Thanks!
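One detail about the checks in the forward pass: `isnan()` only flags NaN. When float16 overflows, the value typically becomes `inf` first and only turns into NaN later (e.g. via `inf - inf` or `0 * inf`), so a NaN-only check can fire one or more layers after the real problem. A minimal sketch of a stricter check that also catches infinities; `assert_finite` is a hypothetical helper, not a PyTorch API:

```python
import torch


def assert_finite(name: str, tensor: torch.Tensor) -> None:
    """Hypothetical debug helper: raise if `tensor` contains NaN or +/-inf."""
    if not torch.isfinite(tensor).all():
        n_nan = int(torch.isnan(tensor).sum())
        n_inf = int(torch.isinf(tensor).sum())
        raise AssertionError(
            f"{name}: {n_nan} NaN and {n_inf} inf values "
            f"out of {tensor.numel()} (dtype={tensor.dtype})"
        )


# 70000 exceeds the float16 maximum (65504), so the cast produces inf,
# which an isnan()-only check would not report.
x = torch.tensor([1.0, 70000.0]).half()
print(torch.isnan(x).any().item())  # False: no NaN yet, only inf
try:
    assert_finite("x", x)
except AssertionError as err:
    print(err)
```

In the forward pass, a call such as `assert_finite("backbone", x)` could replace `assert not x.isnan().any()`.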

Could you describe your use case in a bit more detail, e.g. which layer creates the invalid outputs?
It seems a FuseDecoder is used, but I don’t know what architecture this refers to.
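To find which layer first produces invalid values without adding asserts by hand, one option is to register a forward hook (`nn.Module.register_forward_hook`) on every submodule and record the first one whose output is not finite. A minimal sketch; the function name `find_first_nonfinite` is hypothetical:

```python
import torch
import torch.nn as nn


def find_first_nonfinite(model: nn.Module, *inputs):
    """Run `model` once and return the name of the first submodule (in the
    order their forward calls finish) whose tensor output contains NaN or
    inf, or None if every output is finite."""
    culprit = []
    handles = []

    def make_hook(name):
        def hook(module, inp, output):
            outs = output if isinstance(output, (tuple, list)) else (output,)
            for out in outs:
                if (not culprit and torch.is_tensor(out)
                        and not torch.isfinite(out).all()):
                    culprit.append(name)
        return hook

    for name, module in model.named_modules():
        if name:  # skip the root module itself
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        for h in handles:  # always detach the debug hooks again
            h.remove()
    return culprit[0] if culprit else None
```

Because a child's hook fires before its parent's, the reported name is the innermost module that produced the first non-finite output. With the model from the thread this could be called as `find_first_nonfinite(self.model, batch_img)` on a failing batch; the `check_nan` asserts would need to be disabled, since they raise before the function can return.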


I tested it again and found that the NaN already comes out of the backbone:

```python
def forward(self, img):
    with torch.cuda.amp.autocast(self.use_amp):
        x = self.backbone(img)
        x = self.neck(x)

        if self.qa.check_nan:
            assert not x.isnan().any()  # triggers the error
            assert not self.dict_feats["layer2"].isnan().any()
            for p in self.decoder2.parameters():
                assert not p.isnan().any()
        ...
```

and the error is:

```
  File "/home/user/software/python/anaconda/anaconda3/envs/conda-general/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/…net.py", line 337, in forward
    assert not x.isnan().any()
AssertionError
```

where

```python
...
self.neck = nn.Identity()
...
self.backbone = shufflenetv2.shufflenet_v2_x0_5()
# modified so that it no longer uses the final FC layer
...
```
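Since `self.neck` is an `nn.Identity`, the failing `x` is exactly the backbone output. One way to check whether mixed precision is the trigger is to run the same backbone on the same batch with and without autocast and compare. A minimal sketch; `compare_amp` is a hypothetical helper, and the choice of float16 on CUDA and bfloat16 on CPU (the type CPU autocast supports) is an assumption:

```python
import torch
import torch.nn as nn


@torch.no_grad()
def compare_amp(module: nn.Module, x: torch.Tensor) -> dict:
    """Run `module` on `x` in float32 and under autocast; report whether each
    output is finite and the largest absolute difference between them."""
    device_type = x.device.type
    # Assumption: float16 on CUDA (as in the thread), bfloat16 on CPU.
    amp_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16

    out_fp32 = module(x)
    with torch.autocast(device_type=device_type, dtype=amp_dtype):
        out_amp = module(x)

    fp32_ok = bool(torch.isfinite(out_fp32).all())
    amp_ok = bool(torch.isfinite(out_amp).all())
    diff = ((out_fp32 - out_amp.float()).abs().max().item()
            if fp32_ok and amp_ok else float("nan"))
    return {"fp32_finite": fp32_ok, "amp_finite": amp_ok, "max_abs_diff": diff}
```

If `fp32_finite` is True but `amp_finite` is False, the reduced-precision path is the likely cause (e.g. an activation exceeding the float16 range); if both are False, the problem probably isn't AMP-specific (e.g. bad input data or weights that are already corrupted). In `train()` mode each call also updates BatchNorm running statistics, so running this on a copy of the backbone may be preferable.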

The modified ShuffleNetV2 looks like this:

```python
# Nearly the same as the torchvision implementation, but without the
# global average pooling (x.mean([2, 3])) and the FC layer
def _forward_impl(self, x: Tensor) -> Tensor:
    # See note [TorchScript super()]
    x = self.conv1(x)
    x = self.maxpool(x)
    x = self.stage2(x)
    x = self.stage3(x)
    x = self.stage4(x)
    x = self.conv5(x)
    return x
```
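Because `_forward_impl` is a plain chain of stages, one way to see whether any stage pushes activations toward the float16 limit (`torch.finfo(torch.float16).max`, 65504) is to replay the stages in float32, outside any autocast context, and record the peak absolute activation after each one. A minimal sketch, assuming the backbone exposes `conv1`, `maxpool`, `stage2`, `stage3`, `stage4` and `conv5` as in the code above; `stage_peaks` is a hypothetical name:

```python
import torch
import torch.nn as nn

FP16_MAX = torch.finfo(torch.float16).max  # 65504.0

# Stage order taken from the modified _forward_impl above.
STAGES = ("conv1", "maxpool", "stage2", "stage3", "stage4", "conv5")


@torch.no_grad()
def stage_peaks(backbone: nn.Module, x: torch.Tensor) -> dict:
    """Replay the backbone stage by stage in float32 and return, per stage,
    the peak |activation| as a fraction of the float16 maximum.
    Call this outside autocast so the activations really stay float32."""
    peaks = {}
    x = x.float()
    for name in STAGES:
        x = getattr(backbone, name)(x)
        peaks[name] = x.abs().max().item() / FP16_MAX
    return peaks
```

A ratio close to or above 1.0 for some stage would overflow in float16. This only inspects stage outputs, so an intermediate inside a stage (e.g. a conv output before its BatchNorm) could still overflow while the stage output looks fine; in that case a per-layer check would be needed.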