Loss begins at 219.7 or so and never budges.
There is nothing wrong with my training data or my training loop. If I swap my AttnAggregator module below for a SAGEConv-based version (sketched further down, just before the AttnAggregator code), training proceeds just fine. If I substitute GATConv()
layers for the GATv2Conv()
layers, the same frozen loss occurs. When I print the final value of out
at the bottom of forward(),
it gives perfectly coherent numbers.
What am I missing?
out tensor([0.1702, 0.1489, 0.1702, 0.1702, 0.1702, 0.1702, 0.1667, 0.1667, 0.1667,
0.1667, 0.1667, 0.1667], grad_fn=<CatBackward0>)
epoch 126 , loss 219.74465942382812
epoch 127 , loss 219.7310791015625
epoch 128 , loss 219.74839782714844
epoch 129 , loss 219.73077392578125
epoch 130 , loss 219.73785400390625
epoch 131 , loss 219.73355102539062
epoch 132 , loss 219.74407958984375
epoch 133 , loss 219.75140380859375
epoch 134 , loss 219.79591369628906
epoch 135 , loss 219.70147705078125
epoch 136 , loss 219.73855590820312
epoch 137 , loss 219.72088623046875
epoch 138 , loss 219.7546844482422
epoch 139 , loss 219.77212524414062
training time elapsed mins 0.7
over 139 epochs
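For reference, the SAGEConv-based version that does train looks roughly like this. It is only a sketch: the layer widths and the ReLU are illustrative, not exactly what I ran.

import torch.nn as nn
import torch_geometric as pyg
from torch import flatten

# Rough sketch of the SAGEConv variant that trains fine (widths are illustrative)
class SageAggregator( nn.Module ):
    def __init__(self ):
        super(SageAggregator, self).__init__()
        self.conv1 = pyg.nn.SAGEConv( 5, 64 )
        self.conv2 = pyg.nn.SAGEConv( 64, 64 )

    def forward(self, x, edgeIndex):
        out = self.conv1(x, edgeIndex).relu()
        out = self.conv2(out, edgeIndex)
        return flatten(out)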
===
import torch
import torch.nn as nn
import torch_geometric as pyg
from torch import flatten
from torch import concat

class AttnAggregator( nn.Module ):
    def __init__(self ):
        super(AttnAggregator, self).__init__()
        self.dropout = 0.004
        self.trainPhase = True
        self.obj_goal_split = None
        self.in_heads = 8
        # first GATv2 layer: 5 input features -> 64 channels per head, 8 heads (concatenated)
        self.conv1 = pyg.nn.GATv2Conv( 5, 64, heads=self.in_heads, dropout=0.0 )
        # second GATv2 layer collapses the 8 concatenated heads back down to 64 channels
        self.conv2 = pyg.nn.GATv2Conv( 64*self.in_heads, 64, concat=False, heads=1, dropout=0.01 )
        self.SOFTo = nn.Softmax(dim=0)
        self.SOFTg = nn.Softmax(dim=0)

    def forward(self, x, edgeIndex):
        assert self.obj_goal_split is not None
        out = self.conv1(x, edgeIndex)
        out = self.conv2(out, edgeIndex)
        out = flatten(out)
        # separate softmax over the object block and the goal block, then re-join
        sp = self.obj_goal_split
        out = concat( (self.SOFTo(out[sp[0]:sp[1]]), self.SOFTg(out[sp[2]:])), 0 )
        return out

    def set_phase(self, isTrainPhase ):
        self.trainPhase = isTrainPhase

    def set_object_goal_split(self, split ):
        self.obj_goal_split = split
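For completeness, the module gets driven roughly like this; the tensors and split indices below are placeholders, not my real graph.

# Placeholder usage, just to show how the module is called
model = AttnAggregator()
model.set_object_goal_split( (0, 6, 6) )      # slice boundaries into the flattened output (placeholder values)
x = torch.rand( 12, 5 )                       # 12 nodes, 5 features each (placeholder)
edgeIndex = torch.randint( 0, 12, (2, 30) )   # random edge index (placeholder)
out = model(x, edgeIndex)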
===
### Adam optimizer ###
params = list( GNNmodel.parameters() )
optimizer = torch.optim.Adam( params, lr=0.005, weight_decay=5.0e-4 )
optimizer.zero_grad()
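The loop itself isn't shown above, but it follows the standard Adam pattern, roughly like this (numEpochs, lossFn, target, x, and edgeIndex are stand-ins for what I actually use):

### training loop (sketch; names are stand-ins) ###
for epoch in range( numEpochs ):
    optimizer.zero_grad()
    out = GNNmodel( x, edgeIndex )
    loss = lossFn( out, target )
    loss.backward()
    optimizer.step()
    print( "epoch", epoch, ", loss", loss.item() )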