Hi, I am trying to use model.half() to get around an out-of-memory error. As a first step I want to verify it on a simple model before moving on to a more complex one. In the example below I apply model.half() to a simple CNN trained on MNIST. As far as I can tell I did everything correctly, but I get NaN for the loss.
I would appreciate it if you could point out where I am making a mistake.
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torchvision import transforms

mnist = torchvision.datasets.MNIST('./data', train=True, download=True,
                                   transform=transforms.ToTensor())
data_loader = DataLoader(mnist, batch_size=20, num_workers=2, shuffle=True)
class Model(nn.Module):
    # Simple CNN: two conv layers (run in fp16) with BatchNorm layers kept in fp32
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Conv2d(1, 10, 3)
        self.bn1 = nn.BatchNorm2d(10)
        self.fc2 = nn.Conv2d(10, 20, 3)
        self.bn2 = nn.BatchNorm2d(20)
        self.fc3 = nn.Linear(11520, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = x.type(torch.float32)   # cast activations up for the fp32 BatchNorm
        x = self.bn1(x)
        x = x.type(torch.float16)   # and back down to fp16 for the next conv
        x = F.relu(self.fc2(x))
        x = x.type(torch.float32)
        x = self.bn2(x)
        x = x.type(torch.float16)
        x = x.view(x.size(0), -1)
        x = self.fc3(x)
        return x
device = torch.device('cuda:6' if torch.cuda.is_available() else 'cpu')
model = Model().half().to(device)
for layer in model.modules():
    if isinstance(layer, nn.BatchNorm2d):
        layer.float()   # keep BatchNorm parameters in fp32
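# Sanity check (sketch): print parameter dtypes to confirm the conv/linear
# weights are float16 while the BatchNorm weights stayed float32.
for name, p in model.named_parameters():
    print(name, p.dtype)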
optimizer = optim.Adam(model.parameters(), lr=0.1)
lr_sch = lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)
criterion = nn.CrossEntropyLoss()
def train(epoch):
    model.train()
    t_loss = 0
    for X, y in data_loader:
        X = X.half().to(device)   # inputs cast to fp16 to match the model
        y = y.long().to(device)
        pred = model(X)
        loss = criterion(pred, y)
        t_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return t_loss / len(data_loader.dataset)
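# Optional debugging probe (sketch): check whether a single fp16 forward pass
# already produces non-finite values, before any backward step.
def check_one_batch():
    X, _ = next(iter(data_loader))
    X = X.half().to(device)
    with torch.no_grad():
        out = model(X)
    print('non-finite outputs:', (~torch.isfinite(out)).any().item())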
num_epochs = 20
train_loss = []
for epoch in range(num_epochs):
    t_loss = train(epoch)
    train_loss.append(t_loss)
    print(f'Epoch: {epoch}: Loss: {t_loss:0.4}')
Epoch: 0: Loss: 4.346e-05
Epoch: 1: Loss: nan
Epoch: 2: Loss: nan
Epoch: 3: Loss: nan
Epoch: 4: Loss: nan
Epoch: 5: Loss: nan
Epoch: 6: Loss: nan
Epoch: 7: Loss: nan
Epoch: 8: Loss: nan
Epoch: 9: Loss: nan
Epoch: 10: Loss: nan
Epoch: 11: Loss: nan
Epoch: 12: Loss: nan
Epoch: 13: Loss: nan
Epoch: 14: Loss: nan
Epoch: 15: Loss: nan
Epoch: 16: Loss: nan
Epoch: 17: Loss: nan
Epoch: 18: Loss: nan
Epoch: 19: Loss: nan
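Separately, I was wondering whether I should be using automatic mixed precision (torch.cuda.amp with autocast and GradScaler) instead of calling .half() on the model manually, since the loss scaling is supposed to avoid fp16 gradient underflow. Below is a rough sketch of what I have in mind, reusing the same Model, data_loader, and criterion as above; I have not yet checked whether it actually avoids the NaN.

model = Model().to(device)                 # weights stay in fp32
optimizer = optim.Adam(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

def train_amp(epoch):
    model.train()
    t_loss = 0
    for X, y in data_loader:
        X = X.to(device)
        y = y.long().to(device)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():    # ops run in fp16 where it is safe
            pred = model(X)
            loss = criterion(pred, y)
        t_loss += loss.item()
        scaler.scale(loss).backward()      # scaled loss to avoid fp16 underflow
        scaler.step(optimizer)
        scaler.update()
    return t_loss / len(data_loader.dataset)

Would that be the recommended direction here, or should the manual .half() approach above also work?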