Hi
I have this doubt
I want to create a confusion matrix and I’m a beginner. I’m getting as error at
This is evaluation function
Blockquote
pred = torch.tensor()
confusion_matrix = torch.zeros(num_classes, num_classes)
def evaluation(dataloader, model):
total, correct = 0,0
for data in dataloader:
inputs, labels = data
inputs, labels = inputs.to(device), labels.to(device)
# inputs are put in the gpu
outputs = model(inputs)
_, pred = torch.max(outputs.data,1)
for t, p in zip(classes.view(-1), preds.view(-1)):
confusion_matrix[t.long(), p.long()] +=1
total += labels.size(0)
correct +=(pred==labels).sum().item()
return 100 * correct/total
This is the training loop
Blockquote
loss_epoch_arr =
max_epochs = 1
min_loss = 1000
n_iters = np.ceil(1050/batch_size)
for epoch in range(max_epochs):
for i, data in enumerate(trainloader, 0):
inputs, labels = data
inputs, labels = inputs.to(device), labels.to(device)
#print(inputs.shape)
opt.zero_grad()
outputs = resnet(inputs)
loss = loss_fn(outputs, labels)
loss.backward()
opt.step()
if min_loss> loss.item():
min_loss = loss.item()
best_model = copy.deepcopy(resnet.state_dict())
print('Min loss %0.2f' % min_loss)
if i%100 == 0:
print('Iteration : %d/%d, Loss : %0.2f' % (i, n_iters, loss.item()))
del inputs, labels, outputs
torch.cuda.empty_cache()
loss_epoch_arr.append(loss.item())
print('Epoch: %d/%d, Test acc: %0.2f, Train acc : %0.2f' % (epoch, max_epochs, evaluation(testloader,resnet),
evaluation(trainloader, resnet)))
plt.plot(loss_epoch_arr)
plt.show()
I’m getting the following error
Blockquote
AttributeError Traceback (most recent call last)
in
32 loss_epoch_arr.append(loss.item())
33
—> 34 print(‘Epoch: %d/%d, Test acc: %0.2f, Train acc : %0.2f’ % (epoch, max_epochs, evaluation(testloader,resnet),
35 evaluation(trainloader, resnet)))
36 plt.plot(loss_epoch_arr)
in evaluation(dataloader, model)
11 # from net to model - 2 change from lenet
12 _, pred = torch.max(outputs.data,1)
—> 13 for t, p in zip(classes.view(-1), preds.view(-1)):
14 confusion_matrix[t.long(), p.long()] +=1
15 total += labels.size(0)
AttributeError: ‘tuple’ object has no attribute ‘view’
Blockquote
Can someone help me with this?