The following code only works if `optimizer.step()` is used. If `optimizer.step` (without parentheses) is used instead, there is no training progress, so I believe the parameters are not updated — yet no warning or exception is raised.
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import numpy as np
import torch.utils.data as DL
import torch.nn.functional as F
import time
#
# Preprocessing pipeline: ToTensor, then Normalize((0.5,), (0.5,)), which
# computes (x - 0.5) / 0.5 per channel.
TRF = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)
#
# Build the MNIST train split first, then the test split; both live under
# ./data1 and are downloaded there if missing.
T_DSET, V_DSET = (
    dsets.MNIST(root='./data1', transform=TRF, train=is_train, download=True)
    for is_train in (True, False)
)
# Training batches are shuffled; validation batches keep dataset order.
TRAINLOADER = DL.DataLoader(dataset=T_DSET, batch_size=32, shuffle=True)
V_LOADER = DL.DataLoader(dataset=V_DSET, batch_size=32)
#
# Log-softmax output paired with NLLLoss is equivalent to cross-entropy loss.
m = nn.LogSoftmax(dim=1)
loss = nn.NLLLoss()
# Use the GPU when one is available; otherwise fall back to the CPU instead of
# failing later when the model/tensors are moved with .to(device).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
#
class Model(nn.Module):
    """Three-layer fully connected classifier returning log-probabilities.

    Args:
        H1: width of the first hidden layer.
        H2: width of the second hidden layer.

    The input layer expects 28*28 features per sample and the output layer
    has 10 classes.
    """

    def __init__(self, H1, H2):
        super().__init__()
        self.layer1 = nn.Linear(28 * 28, H1)
        self.layer2 = nn.Linear(H1, H2)
        self.layer3 = nn.Linear(H2, 10)

    def forward(self, x):
        """Map a batch of inputs to per-class log-probabilities.

        Args:
            x: tensor whose first dimension is the batch; the remaining
                dimensions must hold 28*28 values per sample.

        Returns:
            Tensor of shape (batch, 10) with log-softmax over dim 1, suitable
            for nn.NLLLoss.
        """
        # flatten(1) keeps the batch dimension and, unlike view(), also works
        # on non-contiguous inputs.
        x = torch.flatten(x, 1)
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        # Compute log-softmax locally rather than through a module-level
        # global, so the class is self-contained.
        return F.log_softmax(self.layer3(x), dim=1)
#
MODEL = Model(256, 128)
#
# Move the model to the target device before constructing the optimizer, as
# recommended by the torch.optim documentation.
MODEL.to(device)
optimizer = torch.optim.SGD(MODEL.parameters(), lr=0.01)
#
for epoch in range(10):
    MODEL.train()
    running_loss = 0.0
    start = time.time()
    for X, y in TRAINLOADER:
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        yhat = MODEL(X)
        LOSS = loss(yhat, y)
        LOSS.backward()
        # BUG FIX: `optimizer.step` without parentheses merely looks up the
        # bound method and discards it. That is a valid expression statement,
        # so Python raises nothing, but the parameters are never updated.
        # The method must be called.
        optimizer.step()
        running_loss += LOSS.item()
    end = time.time()
    # Validation pass: eval mode, and no autograd graph is recorded.
    MODEL.eval()
    val_loss = 0.0
    with torch.no_grad():
        for XX, yy in V_LOADER:
            XX, yy = XX.to(device), yy.to(device)
            yhatv = MODEL(XX)
            VL = loss(yhatv, yy)
            val_loss += VL.item()
    # Mean train loss per batch, mean validation loss per batch, epoch seconds.
    print(running_loss/len(TRAINLOADER), val_loss/len(V_LOADER), (end-start))
#
# Calculate the prediction probabilities for a single validation image
# (the 55th sample of V_DSET) and print them next to its label.
MODEL.to('cpu')
# Inference only: switch to eval mode and (below) disable autograd.
MODEL.eval()
my_iter = iter(V_DSET)
for i in range(55):  # advance to the 55th sample; only the last one is kept
    Kx, Ky = next(my_iter)
#
with torch.no_grad():
    # MODEL returns log-probabilities; exp() converts them to probabilities.
    OUT = torch.exp(MODEL(Kx))
# NOTE(review): 'ist' (actual) labels the dataset target and 'Soll' (expected)
# labels the prediction — the German labels may be swapped; confirm intent.
print('ist ',Ky)
print('Soll ',OUT)