Model forward pass in AMP gives NaN

Code:

import torch
import torch.nn as nn

img = torch.randn(4, 3, 256, 256)      # full-resolution image (unused in this minimal repro)
small_img = torch.randn(4, 3, 64, 64)  # low-resolution input fed to the model

class SuperResolution(nn.Module):
    def __init__(self):
        super().__init__()

        self.m1 = nn.Sequential(
            nn.Conv2d(3, 128, 3, padding=1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(128, 256, 3, padding=1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(256, 512, 3, padding=1, bias=False),
            nn.ReLU(True),
        )
   
    def forward(self, x):
        print('x:', x.mean())
        z = self.m1(x)
        print('m1:', z.mean())
        print('\n====================\n')
        return z

device = torch.device('cuda')  # static for debugging
model = SuperResolution().to(device)
small_img = small_img.to(device)

print('Normal Precision\n')
output = model(small_img)

print('Automatic Mixed Precision\n')
with torch.autocast('cuda', enabled=True):
    output = model(small_img)

Output:

Normal Precision

x: tensor(0.0062, device='cuda:0')
m1: tensor(0.0014, device='cuda:0', grad_fn=<MeanBackward0>)

====================

Automatic Mixed Precision

x: tensor(0.0062, device='cuda:0')
m1: tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<MeanBackward0>)

====================

With full precision the forward pass produces sensible values, but under autocast the output of m1 is NaN.
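
One extra check that might help narrow this down (a sketch, not yet run on this exact setup) is to run the same layers in plain float16 without autocast, to see whether ordinary fp16 convolutions already produce NaN on this GPU or whether the problem is specific to autocast:

import copy

with torch.no_grad():
    # Cast a copy of the model and the input to float16 manually, bypassing autocast.
    fp16_model = copy.deepcopy(model).half()
    fp16_out = fp16_model.m1(small_img.half())
    print('plain fp16 m1:', fp16_out.mean())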

Specs:

Fedora 40
Ryzen 7 4800H
NVIDIA GeForce GTX 1660 Ti

nvidia-smi output:

+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.58.02              Driver Version: 555.58.02      CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce GTX 1660 Ti     Off |   00000000:01:00.0  On |                  N/A |
| N/A   53C    P8              6W /   80W |    1675MiB /   6144MiB |     19%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

PyTorch was installed with mamba (a conda equivalent) as follows:

mamba create --name torch python=3.11.9
mamba activate torch
mamba install pytorch torchvision torchtext torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia

The environment was created and PyTorch installed today (05.08.2024).
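
For reference, here is a short snippet to print the versions and GPU compute capability that PyTorch itself reports inside this environment:

import torch

print('torch:', torch.__version__)
print('CUDA (torch build):', torch.version.cuda)
print('cuDNN:', torch.backends.cudnn.version())
print('GPU:', torch.cuda.get_device_name(0))
print('compute capability:', torch.cuda.get_device_capability(0))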