Code:
import torch
import torch.nn as nn

img = torch.randn(4, 3, 256, 256)
small_img = torch.randn(4, 3, 64, 64)


class SuperResolution(nn.Module):
    def __init__(self):
        super().__init__()
        self.m1 = nn.Sequential(
            nn.Conv2d(3, 128, 3, padding=1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(128, 256, 3, padding=1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(256, 512, 3, padding=1, bias=False),
            nn.ReLU(True),
        )

    def forward(self, x):
        print('x:', x.mean())
        z = self.m1(x)
        print('m1:', z.mean())
        print('\n====================\n')
        return x


device = torch.device('cuda')  # static for debugging
model = SuperResolution().to(device)
small_img = small_img.to(device)

print('Normal Precision\n')
output = model(small_img)

print('Automatic Mixed Precision\n')
with torch.autocast('cuda', enabled=True):
    output = model(small_img)
Output:
Normal Precision
x: tensor(0.0062, device='cuda:0')
m1: tensor(0.0014, device='cuda:0', grad_fn=<MeanBackward0>)
====================
Automatic Mixed Precision
x: tensor(0.0062, device='cuda:0')
m1: tensor(nan, device='cuda:0', dtype=torch.float16, grad_fn=<MeanBackward0>)
====================
In short: in full precision the output of m1 has a finite mean, but under autocast (float16) the mean of the same output is nan.
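To narrow down which layer first produces the nan, something like the following could help. It is only a minimal sketch: `find_first_nonfinite` is a hypothetical helper name (not part of PyTorch) that registers forward hooks on the leaf modules and reports the first one whose output contains NaN or Inf.

```python
import contextlib

import torch
import torch.nn as nn


def find_first_nonfinite(model: nn.Module, x: torch.Tensor, use_autocast: bool = False):
    """Return the name of the first leaf module whose output has NaN/Inf, else None.

    Hypothetical debugging helper; the name and signature are assumptions.
    """
    bad = []    # names of offending modules, in execution order
    hooks = []  # hook handles, removed again after the forward pass

    def make_hook(name):
        def hook(module, inputs, output):
            # Check the output at the moment the module returns it.
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                bad.append(name)
        return hook

    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf modules only
            hooks.append(module.register_forward_hook(make_hook(name)))

    # Use autocast on the input's device type only when requested.
    ctx = torch.autocast(x.device.type) if use_autocast else contextlib.nullcontext()
    try:
        with torch.no_grad(), ctx:
            model(x)
    finally:
        for handle in hooks:
            handle.remove()
    return bad[0] if bad else None
```

With the model above, `find_first_nonfinite(model, small_img, use_autocast=True)` would return a name such as `m1.0` if the first convolution is already the one producing non-finite values, or `None` if every layer output is finite.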
Specs:
Fedora 40
Ryzen 7 4800H
NVIDIA GeForce GTX 1660 Ti
Nvidia-smi:
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.58.02              Driver Version: 555.58.02      CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce GTX 1660 Ti     Off |   00000000:01:00.0  On |                  N/A |
| N/A   53C    P8              6W /   80W |    1675MiB /   6144MiB |     19%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
PyTorch was installed with mamba (a conda equivalent) as follows:
mamba create --name torch python=3.11.9
mamba activate torch
mamba install pytorch torchvision torchtext torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia
The environment was created and PyTorch was installed today (05.08.2024).
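The versions the interpreter actually picks up can differ from what was requested at install time, so it may be worth confirming them from inside Python. A minimal sketch, assuming only standard PyTorch attributes; `env_report` is a hypothetical helper name:

```python
import torch


def env_report() -> dict:
    """Collect the PyTorch/CUDA versions the running interpreter actually uses."""
    info = {
        "torch": torch.__version__,
        "cuda_build": torch.version.cuda,         # None for CPU-only builds
        "cuda_available": torch.cuda.is_available(),
        "cudnn": torch.backends.cudnn.version(),  # None if cuDNN is unavailable
    }
    if info["cuda_available"]:
        # Only query the device when a GPU is visible to PyTorch.
        info["gpu"] = torch.cuda.get_device_name(0)
        info["capability"] = torch.cuda.get_device_capability(0)
    return info


for key, value in env_report().items():
    print(f"{key}: {value}")
```

Including this output alongside the nvidia-smi table makes it easier for others to tell whether the nan depends on a particular PyTorch, CUDA, or cuDNN build.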