I am getting above error while executing following function. My module runs very well for Training and Validation, but while in testing phase it gives me above error. Any suggestion/solution are welcome @ptrblck
def evaluate_test(self):
# restore model args
args = self.args
# evaluation mode
self.model.load_state_dict(torch.load(osp.join(self.args.save_path, 'max_acc.pth'))['params'])
self.model.eval()
record = np.zeros((2700, 2)) # loss and acc
label = torch.arange(args.eval_way, dtype=torch.int16).repeat(args.eval_query)
label = label.type(torch.LongTensor)
if torch.cuda.is_available():
label = label.cuda()
print('best epoch {}, best val acc={:.4f} + {:.4f}'.format(
self.trlog['max_acc_epoch'],
self.trlog['max_acc'],
self.trlog['max_acc_interval']))
confusion_matrix=torch.zeros(args.eval_way,args.eval_way)
with torch.no_grad():
for i, batch in tqdm(enumerate(self.test_loader, 1)):
if torch.cuda.is_available():
data, real_label= [_.cuda() for _ in batch]
else:
data = batch[0]
real_label = real_label.type(torch.LongTensor)[args.eval_way:]
if torch.cuda.is_available():
real_label = real_label.cuda()
logits = self.model(data)
loss = F.cross_entropy(logits, label)
pred = torch.argmax(logits, dim=1)
actual_pred=torch.where(pred.eq(label),real_label,real_label[label[pred]])
#print('actual_label is:', real_label)
#print('actual_pred is:',actual_pred)
acc = count_acc(logits, label)
record[i-1, 0] = loss.item()
#print('record is:', record)
record[i-1, 1] = acc
real_label = batch[1]
#print(real_label)
for t,p in zip(real_label.view(-1), actual_pred.view(-1)):
confusion_matrix[t.long(),p.long()]+=1
cm=confusion_matrix.numpy()
import pandas as pd
cm_col=cm / cm.sum(axis=1)
cm_row=cm / cm.sum(axis=0)
It is causing me error at this line in the above error
record[i-1, 0] = loss.item()