the file of this function is imported and when running (python train.py --epochs 1) i got this error:
--------------Training is starting-------------
Traceback (most recent call last):
File "train.py", line 48, in <module>
suppFunctions.train_network(model, criterion, optimizer, trainloader, epochs, 20, power)
File "/home/workspace/ImageClassifier/suppFunctions.py", line 128, in train_network
loss = criterion(outputs, labels)
TypeError: 'Adam' object is not callable
the function train_network
:
def train_network(model, criterion, optimizer, loader, epochs = 3, print_every=20, power='gpu'):
steps = 0
running_loss = 0
print("--------------Training is starting------------- ")
for e in range(epochs):
running_loss = 0
for ii, (inputs, labels) in enumerate(loader):
steps += 1
if torch.cuda.is_available() and power =='gpu':
inputs, labels = inputs.to('cuda'), labels.to('cuda')
optimizer.zero_grad()
# Forward and backward passes
outputs = model.forward(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if steps % print_every == 0:
model.eval()
vlost = 0
accuracy=0
for ii, (inputs2,labels2) in enumerate(vloader):
optimizer.zero_grad()
if torch.cuda.is_available():
inputs2, labels2 = inputs2.to('cuda:0') , labels2.to('cuda:0')
model.to('cuda:0')
with torch.no_grad():
outputs = model.forward(inputs2)
vlost = criterion(outputs,labels2)
ps = torch.exp(outputs).data
equality = (labels2.data == ps.max(1)[1])
accuracy += equality.type_as(torch.FloatTensor()).mean()
vlost = vlost / len(vloader)
accuracy = accuracy /len(vloader)
print("Epoch: {}/{}... ".format(e+1, epochs),
"Loss: {:.4f}".format(running_loss/print_every),
"Validation Lost {:.4f}".format(vlost),
"Accuracy: {:.4f}".format(accuracy))
running_loss = 0
print("-------------- Finished training -----------------------")
print("Dear User I the ulitmate NN machine trained your model. It required")
print("----------Epochs: {}------------------------------------".format(epochs))
print("----------Steps: {}-----------------------------".format(steps))
print("That's a lot of steps")