Without AMP, training converges as expected:
import torch
from torch.cuda.amp import autocast, GradScaler
from torchvision import models

model = models.mobilenet_v2(pretrained=True).cuda()
loss_fnc = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

X = torch.randn((32, 3, 300, 300), dtype=torch.float32).cuda()
y = torch.randint(0, 1000, (32,), dtype=torch.long).cuda()

model.train()
for j in range(30):
    optimizer.zero_grad()
    y_hat = model(X)
    loss = loss_fnc(y_hat, y)
    loss.backward()
    optimizer.step()
    print(loss.item())
Output:
8.039933204650879
5.690041542053223
3.4787116050720215
1.607206106185913
0.6231755614280701
0.23825135827064514
0.08544095605611801
0.04335329309105873
0.016259444877505302
0.01174827478826046
0.0069425650872290134
0.004459714516997337
0.003734807949513197
0.0024659112095832825
0.0027059323620051146
........................
But with AMP it gives NaNs:
model = models.mobilenet_v2(pretrained=True).cuda()
loss_fnc = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

X = torch.randn((32, 3, 300, 300), dtype=torch.float32).cuda()
y = torch.randint(0, 1000, (32,), dtype=torch.long).cuda()

scaler = GradScaler()

model.train()
for j in range(30):
    optimizer.zero_grad()
    # forward pass and loss in mixed precision
    with autocast():
        y_hat = model(X)
        loss = loss_fnc(y_hat, y)
    # scaled backward, unscaled optimizer step, then update the loss scale
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    print(loss.item())
Output:
8.393239974975586
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
.......................
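As a rough, untested sketch, one way to narrow this down is to check whether non-finite values already appear in the autocast forward pass (before any scaled backward or optimizer step), and to watch the loss scale reported by the scaler:

# Untested sketch: locate where NaN/Inf first appears under autocast
with autocast():
    y_hat = model(X)
    loss = loss_fnc(y_hat, y)
print(torch.isfinite(y_hat).all().item())  # False -> NaN/Inf already in the fp16 logits
print(torch.isfinite(loss).item())         # False -> NaN/Inf in the loss itself
print(scaler.get_scale())                  # current loss scale maintained by GradScaler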