I load a pre-trained model and test it on the same data in eval mode, but I get a different result on every run.
# Build the graph encoder, the two classifiers and the domain discriminator.
# NOTE(review): unlike the classifiers, `gnn` is not moved with `.to(device)`;
# presumably DAGNN places its own parameters using the `device` argument --
# confirm against the DAGNN definition.
gnn = DAGNN(
    in_dim=args.in_dim,
    n_hid=args.n_hid,
    n_heads=args.n_heads,
    n_layers=args.n_layers,
    agnostic_type=0,
    dropout=args.dropout,
    device=device,
    shared_node_type=shared_node_types,
    shared_relations_type=shared_edge_types,
)
c_s = Classifier(args.n_hid, 100, num_cls).to(device)
c_t = Classifier(args.n_hid, 100, num_cls).to(device)
discriminator = Discriminator(args.n_hid, 100, 2).to(device)

# NOTE(review): the original paste lost its indentation, so the extent of this
# `if` body is an assumption: everything up to the `raise` is treated as the
# "load a checkpoint, evaluate it, stop" debugging path -- confirm.
if args.loadfile != -1:
    # Restore every module from its per-epoch checkpoint file.
    # map_location keeps the load working when the checkpoint was saved on a
    # different device than the one used for this run.
    for prefix, module in (
        ("gnn", gnn),
        ("c_s", c_s),
        ("c_t", c_t),
        ("discriminator", discriminator),
    ):
        checkpoint_path = (
            args.premodel_dir + '/' + prefix + args.src_domain + args.tgt_domain
            + '_epoch_' + str(args.loadfile) + '.pth'
        )
        module.load_state_dict(torch.load(checkpoint_path, map_location=device))

    epoch = 0
    # NOTE(review): eval() only affects layers that consult `module.training`.
    # If results still differ between runs, check for randomness outside the
    # modules (e.g. how the test batches are built) -- not visible here.
    acc, n_correct, n_total = test(gnn, c_s, c_t, discriminator, tgt_test_data, device, epoch, args)
    print("TGT Test epoch =[{}/{}] Avg Acc = {:.5f} n_correct = {} n_total = {}".format(epoch+1,args.n_epoch,acc,n_correct,n_total))
    acc, n_correct, n_total = test(gnn, c_s, c_t, discriminator, src_test_data, device, epoch, args)
    print("SRC Test epoch =[{}/{}] Avg Acc = {:.5f} n_correct = {} n_total = {}".format(epoch+1,args.n_epoch,acc,n_correct,n_total))
    # Deliberate hard stop once the loaded checkpoint has been evaluated.
    raise RuntimeError
The test function is listed below:
def test(gnn, c_s, c_t, discriminator, data_list, device, epoch, args):
    """Evaluate classification accuracy of ``gnn`` + ``c_s`` on ``data_list``.

    All four modules are switched to eval mode, but only ``gnn`` and ``c_s``
    take part in the prediction; ``c_t``, ``discriminator`` and ``epoch`` are
    accepted for signature compatibility and are not otherwise used.

    Args:
        gnn: Encoder called as ``gnn(node_feature, node_type, edge_index,
            edge_type)``; its output is indexed with ``x_ids``.
        c_s: Classifier applied to the selected node representations.
        c_t: Second classifier (only put into eval mode).
        discriminator: Domain discriminator (only put into eval mode).
        data_list: Indexable collection; ``data_list[i]`` yields 6-tuples
            ``(node_feature, node_type, edge_index, edge_type, x_ids,
            batch_ylabel)``.
        device: Device the batch tensors are moved to.
        epoch: Unused.
        args: Object with an ``n_batch`` attribute; the first ``n_batch``
            entries of ``data_list`` are evaluated.

    Returns:
        Tuple ``(acc, n_correct, n_total)``: accuracy as a 0-d NumPy array,
        the number of correct predictions as a 0-d CPU tensor, and the number
        of evaluated samples as an ``int``. With no samples the accuracy is
        0.0 instead of raising ``ZeroDivisionError``.
    """
    for module in (gnn, c_s, c_t, discriminator):
        module.eval()

    n_total = 0
    # Start from a tensor so the return type does not depend on whether any
    # batch was actually processed.
    n_correct = torch.zeros((), dtype=torch.long)

    with torch.no_grad():
        for batch_idx in range(args.n_batch):
            for node_feature, node_type, edge_index, edge_type, x_ids, batch_ylabel in data_list[batch_idx]:
                # Call the module itself rather than .forward() so that any
                # registered hooks run as well.
                node_rep = gnn(
                    node_feature.to(device),
                    node_type.to(device),
                    edge_index.to(device),
                    edge_type.to(device),
                )
                class_pred = torch.argmax(c_s(node_rep[x_ids]), dim=-1)
                # reshape (not view) so non-contiguous label tensors also work.
                labels = batch_ylabel.to(device).reshape(class_pred.shape)
                n_correct += class_pred.eq(labels).sum().cpu()
                n_total += len(x_ids)

    # max(..., 1) guards the empty case; n_correct is 0 there, so acc is 0.0.
    acc = n_correct * 1.0 / max(n_total, 1)
    return acc.numpy(), n_correct, n_total


# The name starts with "test", so pytest would try to collect this helper as
# a test case (and fail on its arguments) if it is imported into a test module.
test.__test__ = False
The results change between runs like this:
TGT Test epoch =[1/900] Avg Acc = 0.16956 n_correct = 1389 n_total = 8192
SRC Test epoch =[1/900] Avg Acc = 0.91479 n_correct = 7494 n_total = 8192
---------------------------------------------------------------------
TGT Test epoch =[1/900] Avg Acc = 0.17188 n_correct = 1408 n_total = 8192
SRC Test epoch =[1/900] Avg Acc = 0.96106 n_correct = 7873 n_total = 8192