Hello all, I am trying to train a GAT implementation on data that has already been prepared for use in a GCN, so I expect the format to be the same. However, the issue is that the per-epoch loss quickly gets stuck at a small value.
I've been reading other posts that suggest normalising the data, but since no normalisation was used for the GCN implementation I'm not sure it would help, although I'll give it a try soon.
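For reference, if I do try it, I was only planning on the usual row-normalisation of the feature matrix, something like this rough sketch (the tensor here is a made-up stand-in for my real dense (N, F) feature matrix):

import torch

features = torch.rand(5, 4)  # stand-in for my real (N, F) feature matrix

# row-normalise so each node's feature vector sums to 1
row_sum = features.sum(dim=1, keepdim=True)
row_sum[row_sum == 0] = 1.0  # guard against empty rows
features = features / row_sum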
I’ve made some small changes to the loss calculation from this pyGAT code. Does there seem to be any error in what I’ve changed?
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss


class GAT(nn.Module):
    def __init__(self, params):
        """Dense version of GAT."""
        nfeat = params['input_dim']
        nheads = params['num_heads']
        nhid = params['hidden_dim']
        # inpdropout = nn.Dropout(params['input_dropout'])
        # dropout = nn.Dropout(params['dropout'])
        dropout = params['dropout']
        self.nclass = params['num_labels']
        alpha = params['alpha']
        # manually set here again
        # self.num_gat_layer = params['num_gat_layer']
        super(GAT, self).__init__()
        self.dropout = dropout
        # one GraphAttentionLayer per head; their outputs are concatenated in forward()
        self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)]
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)
        # output attention layer maps the concatenated heads to class scores
        self.out_att = GraphAttentionLayer(nhid * nheads, self.nclass, dropout=dropout, alpha=alpha, concat=False)

    def forward(self, features, adjacency_matrix, mask=None, labels=None, flag="Train"):
        all_hidden_states = (features,)
        features = F.dropout(features, self.dropout, training=self.training)
        features = torch.cat([att(features, adjacency_matrix) for att in self.attentions], dim=1)
        features = F.dropout(features, self.dropout, training=self.training)
        features = F.elu(self.out_att(features, adjacency_matrix))
        logits = features
        preds = F.log_softmax(features, dim=1)
        _, preds = torch.max(preds, dim=-1)
        outputs = (preds, all_hidden_states)
        if flag.upper() == "TRAIN":
            loss_fct = CrossEntropyLoss(ignore_index=-1, reduction='none')
            # loss_fct = nn.NLLLoss(ignore_index=-1, reduction='none')
            loss = loss_fct(logits.view(-1, self.nclass), labels.view(-1))
            # scale the per-node loss by the (mean-normalised) train mask, then average
            mask = mask.float()
            mask = mask / mask.mean()
            loss *= mask
            loss = loss.mean()
            outputs = (loss,) + outputs
        return outputs
class GraphAttentionLayer(nn.Module):
    """
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
    """
    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat
        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.empty(size=(2 * out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, h, adj):
        Wh = torch.mm(h, self.W)  # h.shape: (N, in_features), Wh.shape: (N, out_features)
        e = self._prepare_attentional_mechanism_input(Wh)
        zero_vec = -9e15 * torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, Wh)
        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime

    def _prepare_attentional_mechanism_input(self, Wh):
        # Wh.shape (N, out_features)
        # self.a.shape (2 * out_features, 1)
        # Wh1 & Wh2.shape (N, 1)
        # e.shape (N, N)
        Wh1 = torch.matmul(Wh, self.a[:self.out_features, :])
        Wh2 = torch.matmul(Wh, self.a[self.out_features:, :])
        # broadcast add
        e = Wh1 + Wh2.T
        return self.leakyrelu(e)

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'
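For what it's worth, my understanding of the mask handling in forward() is that scaling by mask / mask.mean() and then taking .mean() is just a roundabout way of averaging the per-node loss over the masked nodes, i.e. sum(loss * mask) / sum(mask). A small standalone check with made-up numbers (nothing from my real data):

import torch

loss = torch.rand(10)                                        # per-node losses, as with reduction='none'
mask = torch.tensor([1, 0, 1, 1, 0, 0, 1, 0, 1, 1]).float()  # toy 0/1 train mask

scaled = (loss * (mask / mask.mean())).mean()   # what the forward() above computes
masked_avg = (loss * mask).sum() / mask.sum()   # plain average over the masked nodes

print(torch.allclose(scaled, masked_avg))  # prints True

So I don't think the masking itself changes the loss scale, but please correct me if I've misread it.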
P.S. Using a log-loss calculation does not seem to help either, as the loss just becomes increasingly large and negative…
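(My understanding is that CrossEntropyLoss applies log_softmax internally and expects raw scores, whereas NLLLoss expects log-probabilities, so the two should only agree when log_softmax is applied first; feeding raw scores to NLLLoss is the only way I can see the loss going negative. A quick check with made-up tensors:)

import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, NLLLoss

scores = torch.randn(4, 3)            # made-up (N, nclass) outputs
labels = torch.tensor([0, 2, 1, 2])

ce = CrossEntropyLoss()(scores, labels)
nll = NLLLoss()(F.log_softmax(scores, dim=1), labels)
print(torch.allclose(ce, nll))        # True: these should match

print(NLLLoss()(scores, labels))      # raw scores instead of log-probs: can come out negative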
The training output:
Epoch: 1% 1/160 [00:05<15:53, 5.99s/it]Epoch: 1, Train Acc: 0.2822, Dev Acc: 0.2900, Test Acc: 0.2550, Train F1: 0.1467, Dev F1: 0.1499, Test F1: 0.1355, Loss: 5.7988, Avg_loss: 3.7230
Epoch: 1% 2/160 [00:07<08:36, 3.27s/it]Epoch: 2, Train Acc: 0.2822, Dev Acc: 0.2900, Test Acc: 0.2550, Train F1: 0.1467, Dev F1: 0.1499, Test F1: 0.1355, Loss: 12.7053, Avg_loss: 6.7171
Epoch: 2% 3/160 [00:08<06:18, 2.41s/it]Epoch: 3, Train Acc: 0.2822, Dev Acc: 0.2900, Test Acc: 0.2550, Train F1: 0.1467, Dev F1: 0.1499, Test F1: 0.1355, Loss: 2.7875, Avg_loss: 5.7347
Epoch: 2% 4/160 [00:10<05:12, 2.00s/it]Epoch: 4, Train Acc: 0.2822, Dev Acc: 0.2900, Test Acc: 0.2550, Train F1: 0.1467, Dev F1: 0.1499, Test F1: 0.1355, Loss: 9.4120, Avg_loss: 6.4702
Epoch: 3% 5/160 [00:11<04:24, 1.71s/it]Epoch: 5, Train Acc: 0.2822, Dev Acc: 0.2900, Test Acc: 0.2550, Train F1: 0.1467, Dev F1: 0.1499, Test F1: 0.1355, Loss: 1.9508, Avg_loss: 5.7169
Epoch: 4% 6/160 [00:12<03:51, 1.50s/it]Epoch: 6, Train Acc: 0.2822, Dev Acc: 0.2900, Test Acc: 0.2550, Train F1: 0.1467, Dev F1: 0.1499, Test F1: 0.1355, Loss: 10.7928, Avg_loss: 6.4421
Epoch: 4% 7/160 [00:13<03:30, 1.37s/it]Epoch: 7, Train Acc: 0.2822, Dev Acc: 0.2900, Test Acc: 0.2550, Train F1: 0.1467, Dev F1: 0.1499, Test F1: 0.1355, Loss: 12.9792, Avg_loss: 7.2592
Epoch: 5% 8/160 [00:14<03:15, 1.29s/it]Epoch: 8, Train Acc: 0.2822, Dev Acc: 0.2900, Test Acc: 0.2550, Train F1: 0.1467, Dev F1: 0.1499, Test F1: 0.1355, Loss: 1.1368, Avg_loss: 6.5789
Epoch: 6% 9/160 [00:15<03:06, 1.23s/it]Epoch: 9, Train Acc: 0.2822, Dev Acc: 0.2900, Test Acc: 0.2550, Train F1: 0.1467, Dev F1: 0.1499, Test F1: 0.1355, Loss: 2.0879, Avg_loss: 6.1298
Epoch: 6% 10/160 [00:16<02:59, 1.19s/it]Epoch: 10, Train Acc: 0.2822, Dev Acc: 0.2900, Test Acc: 0.2550, Train F1: 0.1467, Dev F1: 0.1499, Test F1: 0.1355, Loss: 1.1034, Avg_loss: 5.6729
Epoch: 7% 11/160 [00:17<02:53, 1.17s/it]Epoch: 11, Train Acc: 0.2822, Dev Acc: 0.2900, Test Acc: 0.2550, Train F1: 0.1467, Dev F1: 0.1499, Test F1: 0.1355, Loss: 1.1384, Avg_loss: 5.2950
Epoch: 8% 12/160 [00:19<02:50, 1.15s/it]Epoch: 12, Train Acc: 0.2822, Dev Acc: 0.2900, Test Acc: 0.2550, Train F1: 0.1467, Dev F1: 0.1499, Test F1: 0.1355, Loss: 2.7059, Avg_loss: 5.0959
Epoch: 8% 13/160 [00:20<02:47, 1.14s/it]Epoch: 13, Train Acc: 0.2822, Dev Acc: 0.2900, Test Acc: 0.2550, Train F1: 0.1467, Dev F1: 0.1499, Test F1: 0.1355, Loss: 3.4661, Avg_loss: 4.9795
Epoch: 9% 14/160 [00:21<02:52, 1.18s/it]Epoch: 14, Train Acc: 0.2822, Dev Acc: 0.2900, Test Acc: 0.2550, Train F1: 0.1467, Dev F1: 0.1499, Test F1: 0.1355, Loss: 1.1119, Avg_loss: 4.7216
Epoch: 9% 15/160 [00:22<02:56, 1.21s/it]Epoch: 15, Train Acc: 0.2822, Dev Acc: 0.2900, Test Acc: 0.2550, Train F1: 0.1467, Dev F1: 0.1499, Test F1: 0.1355, Loss: 1.0986, Avg_loss: 4.4952
Epoch: 10% 16/160 [00:24<02:58, 1.24s/it]Epoch: 16, Train Acc: 0.2822, Dev Acc: 0.2900, Test Acc: 0.2550, Train F1: 0.1467, Dev F1: 0.1499, Test F1: 0.1355, Loss: 1.0986, Avg_loss: 4.2954
Epoch: 11% 17/160 [00:25<03:04, 1.29s/it]Epoch: 17, Train Acc: 0.2822, Dev Acc: 0.2900, Test Acc: 0.2550, Train F1: 0.1467, Dev F1: 0.1499, Test F1: 0.1355, Loss: 1.7396, Avg_loss: 4.1534
Epoch: 11% 18/160 [00:26<03:08, 1.33s/it]Epoch: 18, Train Acc: 0.2822, Dev Acc: 0.2900, Test Acc: 0.2550, Train F1: 0.1467, Dev F1: 0.1499, Test F1: 0.1355, Loss: 1.1108, Avg_loss: 3.9933
Epoch: 12% 19/160 [00:28<03:03, 1.30s/it]Epoch: 19, Train Acc: 0.2822, Dev Acc: 0.2900, Test Acc: 0.2550, Train F1: 0.1467, Dev F1: 0.1499, Test F1: 0.1355, Loss: 1.1016, Avg_loss: 3.8487
Epoch: 12% 20/160 [00:29<02:54, 1.25s/it]Epoch: 20, Train Acc: 0.2822, Dev Acc: 0.2900, Test Acc: 0.2550, Train F1: 0.1467, Dev F1: 0.1499, Test F1: 0.1355, Loss: 1.1091, Avg_loss: 3.7182
Epoch: 13% 21/160 [00:30<02:48, 1.21s/it]Epoch: 21, Train Acc: 0.2822, Dev Acc: 0.2900, Test Acc: 0.2550, Train F1: 0.1467, Dev F1: 0.1499, Test F1: 0.1355, Loss: 1.1154, Avg_loss: 3.5999
Epoch: 14% 22/160 [00:31<02:43, 1.19s/it]Epoch: 22, Train Acc: 0.2822, Dev Acc: 0.2900, Test Acc: 0.2550, Train F1: 0.1467, Dev F1: 0.1499, Test F1: 0.1355, Loss: 1.1008, Avg_loss: 3.4913
Epoch: 14% 23/160 [00:32<02:39, 1.16s/it]Epoch: 23, Train Acc: 0.2822, Dev Acc: 0.2900, Test Acc: 0.2550, Train F1: 0.1467, Dev F1: 0.1499, Test F1: 0.1355, Loss: 2.5263, Avg_loss: 3.4510
Epoch: 15% 24/160 [00:33<02:36, 1.15s/it]Epoch: 24, Train Acc: 0.2822, Dev Acc: 0.2900, Test Acc: 0.2550, Train F1: 0.1467, Dev F1: 0.1499, Test F1: 0.1355, Loss: 1.2944, Avg_loss: 3.3648
Epoch: 16% 25/160 [00:34<02:33, 1.14s/it]Epoch: 25, Train Acc: 0.2822, Dev Acc: 0.2900, Test Acc: 0.2550, Train F1: 0.1467, Dev F1: 0.1499, Test F1: 0.1355, Loss: 1.1014, Avg_loss: 3.2777
Epoch: 16% 26/160 [00:35<02:32, 1.14s/it]Epoch: 26, Train Acc: 0.2822, Dev Acc: 0.2900, Test Acc: 0.2550, Train F1: 0.1467, Dev F1: 0.1499, Test F1: 0.1355, Loss: 1.4125, Avg_loss: 3.2086
Epoch: 17% 27/160 [00:37<02:30, 1.13s/it]Epoch: 27, Train Acc: 0.2822, Dev Acc: 0.2900, Test Acc: 0.2550, Train F1: 0.1467, Dev F1: 0.1499, Test F1: 0.1355, Loss: 1.1026, Avg_loss: 3.1334
Epoch: 18% 28/160 [00:38<02:35, 1.18s/it]Epoch: 28, Train Acc: 0.2822, Dev Acc: 0.2900, Test Acc: 0.2550, Train F1: 0.1467, Dev F1: 0.1499, Test F1: 0.1355, Loss: 1.1000, Avg_loss: 3.0633
Epoch: 18% 29/160 [00:39<02:39, 1.22s/it]Epoch: 29, Train Acc: 0.2822, Dev Acc: 0.2900, Test Acc: 0.2550, Train F1: 0.1467, Dev F1: 0.1499, Test F1: 0.1355, Loss: 1.1764, Avg_loss: 3.0004
Epoch: 19% 30/160 [00:41<02:43, 1.26s/it]Epoch: 30, Train Acc: 0.2822, Dev Acc: 0.2900, Test Acc: 0.2550, Train F1: 0.1467, Dev F1: 0.1499, Test F1: 0.1355, Loss: 1.6934, Avg_loss: 2.9582
Epoch: 19% 31/160 [00:42<02:48, 1.30s/it]Epoch: 31, Train Acc: 0.2822, Dev Acc: 0.2900, Test Acc: 0.2550, Train F1: 0.1467, Dev F1: 0.1499, Test F1: 0.1355, Loss: 1.1175, Avg_loss: 2.9007
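In case it's relevant, since the accuracies are frozen at exactly the same values every epoch, I've also been checking whether the model has collapsed onto predicting a single class, roughly like this (preds and labels here are the prediction and gold-label tensors from forward() above, so this is just a debugging sketch, not a standalone script):

num_classes = int(labels.max()) + 1
print('predicted class counts:', torch.bincount(preds.view(-1), minlength=num_classes))
print('gold class counts:     ', torch.bincount(labels.view(-1)[labels.view(-1) != -1], minlength=num_classes))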
Any help is much appreciated, thank you