Hi all, I’m trying to do learning on graphs and I’m facing an issue where my model’s weights are not updating. I have simplified the model and the weights still don’t move. I’m wondering two things: (1) am I doing any operations that break the computational graph? Most of the intermediate tensors appear to have requires_grad=True and a grad_fn, so I don’t see where it would be cut. (2) Is there a good way to check how gradients are flowing? Currently I just print the parameters between steps and watch whether they change.
I even tried my model with a very large learning rate (1e-1) to see if the weights get updated, but no luck.
I deleted quite a few bits of code that did some graph learning as I wanted to simplify my model and isolate where the breakage is, so please pardon some excess unused variables. Here’s a snippet for my code:
class NeuraLBP(torch.nn.Module):
    """Learned bipartite-matching model.

    Projects passenger ("pax") and driver ("dax") node features into a
    shared hidden space, scores every candidate edge with an MLP
    (``weight_transform``), then refines the resulting n x n score matrix
    with a few iterations of a min-sum message-passing scheme
    (``iter_simp_min_sum_batch_scatter``).

    NOTE(review): the three MPNN layers (mpnn/mpnn2/mpnn3) are constructed
    but never called in forward(), so their parameters receive no gradient
    and will never update -- expected given the author's simplification,
    but worth remembering when watching "the parameters" for changes.
    """

    def __init__(self, d_in_channels, p_in_channels, out_channels, h_size, e_size, n_embed, **kwargs):
        """
        Args:
            d_in_channels: driver node feature width.
            p_in_channels: passenger node feature width.
            out_channels: width of the node embeddings fed to the edge MLP.
            h_size: hidden width used throughout.
            e_size: edge feature width (only consumed by the unused MPNNs).
            n_embed: vocabulary size of the dummy-node embedding table.
        """
        super(NeuraLBP, self).__init__(**kwargs)
        # Feature sizes
        # self.in_channels = in_channels
        self.out_channels = out_channels
        self.h_size = h_size
        self.n_embed = n_embed
        # Layers
        # Embedding for dummy nodes
        self.embed = torch.nn.Embedding(n_embed, out_channels)
        # Translation of node features
        self.dax_lin = torch.nn.Linear(d_in_channels, h_size, bias=False)
        self.pax_lin = torch.nn.Linear(p_in_channels, h_size, bias=False)
        # Custom MPNN from another paper
        self.mpnn = MPNN(in_channels=h_size, out_channels=out_channels, edge_feature=e_size, hidden_size=h_size)
        self.mpnn2 = MPNN(in_channels=out_channels, edge_feature=e_size, out_channels=out_channels, hidden_size=h_size)
        self.mpnn3 = MPNN(in_channels=out_channels, edge_feature=e_size, out_channels=out_channels, hidden_size=h_size)
        # Weight translation layer: maps a concatenated pair of endpoint
        # embeddings to a single edge score in (-1, 1) (final Tanh).
        self.weight_transform = torch.nn.Sequential(torch.nn.Linear(out_channels * 2, h_size * 4),
                                                    torch.nn.ReLU(),
                                                    torch.nn.Linear(h_size * 4, h_size * 4),
                                                    torch.nn.ReLU(),
                                                    torch.nn.Linear(h_size * 4, h_size * 2),
                                                    torch.nn.ReLU(),
                                                    torch.nn.Linear(h_size * 2, 1),
                                                    torch.nn.Tanh())

    def reset_parameters(self):
        # Intentionally a no-op: submodules keep their default (construction
        # time) initialization.  NOTE(review): if a custom re-init is
        # expected here, it has not been implemented.
        return

    def forward(self,
                pax_feat,
                dax_feat,
                pax_embed,
                assign_edge_index,
                assign_edge_attr,
                real_edges_original,
                real_edge_attr,
                sorted_edges,
                sorted_features,
                n_nodes):
        """Return an (n_nodes, n_nodes) matrix of refined edge scores.

        NOTE(review): assign_edge_index, assign_edge_attr,
        real_edges_original, real_edge_attr and sorted_features are
        currently unused -- leftovers from the removed graph-learning code.
        """
        # Project both node types into the shared hidden space; passengers
        # are stacked first, drivers second along the node dimension.
        x = torch.cat((self.pax_lin(pax_feat), self.dax_lin(dax_feat)), dim=0)
        out = x
        # Insert the learned dummy-node embeddings between the passenger
        # rows and the driver rows.  NOTE(review): pax_lin/dax_lin emit
        # h_size features while self.embed emits out_channels -- this cat
        # only works when h_size == out_channels (true in the training
        # script, both 512); confirm that is intended.
        all_feat = torch.cat((out[:pax_feat.size(0), :], self.embed(pax_embed), out[pax_feat.size(0):, :]), dim=0)
        # Score every candidate edge from the concatenation of its two
        # endpoint features.  Assumes sorted_edges is an (E, 2) LongTensor
        # of node indices with E == n_nodes * n_nodes -- TODO confirm.
        wts = self.weight_transform(torch.cat((all_feat[sorted_edges[:, 0], :], all_feat[sorted_edges[:, 1], :]), dim=1))
        wts = wts.view(n_nodes, n_nodes)
        # Refine with 5 min-sum iterations.  The third argument (the
        # "weights" used inside the iteration) is deliberately detached, so
        # gradients reach wts only through the two message tensors (the
        # first two arguments) and through the residual addition below.
        b_a, b_b = self.iter_simp_min_sum_batch_scatter(wts.unsqueeze(0), wts.t().unsqueeze(0), wts.unsqueeze(0).clone().detach(), 5)
        wts = wts + b_b.permute(0, 2, 1).squeeze(0)
        return wts

    def iter_simp_min_sum_batch_scatter(self, m_alpha_beta, m_beta_alpha, weights, n_iter):
        """Batched simplified min-sum message passing.

        Each outgoing message is ``weight - max(incoming messages)``,
        except toward the node that *produced* that max, which gets the
        runner-up instead ("max excluding self").  The exclusion is
        implemented with topk(2) plus two in-place scatter_add_ calls: at
        the argmax position the max is added back and the second-largest
        subtracted, net effect ``weight - second_max`` there.

        Args:
            m_alpha_beta: (B, n, n) alpha->beta messages (grad-carrying).
            m_beta_alpha: (B, n, n) beta->alpha messages (grad-carrying).
            weights: (B, n, n) edge weights; detached by the caller.
            n_iter: number of message-passing iterations.

        Returns:
            Tuple of the final (m_alpha_beta, m_beta_alpha).
        """
        # NOTE(review): n is unused -- presumably left over from the
        # removed graph-learning code.
        n = m_alpha_beta.size(1)
        for _ in range(n_iter):
            # Message passing
            # Largest and second-largest beta->alpha message per column.
            beta_alpha_maxes, beta_alpha_indices = torch.topk(m_beta_alpha, 2, dim=1)
            # Subtract the max everywhere...
            m_alpha_beta_k = weights.permute(0, 2, 1) - beta_alpha_maxes[:, 0, :].unsqueeze(1)
            # ...then at the argmax row only, add the max back and subtract
            # the runner-up instead (max-excluding-self).  The in-place
            # scatter_add_ on a freshly computed tensor is autograd-safe
            # here -- a version-counter violation would raise loudly.
            m_alpha_beta_k = m_alpha_beta_k.scatter_add_(dim=1, index=beta_alpha_indices[:, 0, :].unsqueeze(1),
                                                         src=beta_alpha_maxes[:, 0, :].unsqueeze(1))
            m_alpha_beta_k = m_alpha_beta_k.scatter_add_(dim=1, index=beta_alpha_indices[:, 0, :].unsqueeze(1),
                                                         src=beta_alpha_maxes[:, 1, :].unsqueeze(1) * -1).permute(0, 2, 1)
            # Symmetric update for the alpha->beta direction.
            alpha_beta_maxes, alpha_beta_indices = torch.topk(m_alpha_beta, 2, dim=1)
            m_beta_alpha_k = weights - alpha_beta_maxes[:, 0, :].unsqueeze(1)
            m_beta_alpha_k = m_beta_alpha_k.scatter_add_(dim=1, index=alpha_beta_indices[:, 0, :].unsqueeze(1),
                                                         src=alpha_beta_maxes[:, 0, :].unsqueeze(1))
            m_beta_alpha_k = m_beta_alpha_k.scatter_add_(dim=1, index=alpha_beta_indices[:, 0, :].unsqueeze(1),
                                                         src=alpha_beta_maxes[:, 1, :].unsqueeze(1) * -1).permute(0, 2, 1)
            m_alpha_beta = m_alpha_beta_k
            m_beta_alpha = m_beta_alpha_k
        return m_alpha_beta, m_beta_alpha
And here’s my training loop:
# --- Training script -------------------------------------------------------
model = NeuraLBP(d_in_channels=14, p_in_channels=6, out_channels=512, h_size=512, e_size=1, n_embed=1000)
model.reset_parameters()
model.train()
model = model.to(device)
# lr=1e-1 was a debugging value chosen to make weight changes obvious;
# once training works, drop back to something like 1e-3 for Adam.
optim = torch.optim.Adam(model.parameters(), lr=1e-1)
# CrossEntropyLoss expects RAW LOGITS: it applies log_softmax internally.
loss_fn = torch.nn.CrossEntropyLoss()
list_loss = []
list_acc = []
for i in range(5):
    for d in d_loader:
        optim.zero_grad()
        pax_feat = d.pax_feat.to(device)
        dax_feat = d.dax_feat.to(device)
        pax_embed = d.pax_embed.to(device)
        assign_edge_index = d.assign_edge_index.to(device)
        assign_edge_attr = d.assign_edge_feat[:, 0].view(-1, 1).to(device) * -1.0
        real_edges_original = d.real_edges_original.to(device)
        real_edge_attr = d.real_edge_attr[:, 0].view(-1, 1).to(device) * -1.0
        # BUG FIX: these assignments previously targeted `data.…`, a name
        # not defined in this loop -- the current batch is `d`.
        d.sorted_edges = d.sorted_edges.type(torch.LongTensor)
        d.sorted_features = d.sorted_features.type(torch.FloatTensor)
        sorted_edges = d.sorted_edges.to(device)
        sorted_features = d.sorted_features.to(device) * -1.0
        n_nodes = d.dax_feat.size(0)
        out = model(pax_feat,
                    dax_feat,
                    pax_embed,
                    assign_edge_index,
                    assign_edge_attr,
                    real_edges_original,
                    real_edge_attr,
                    sorted_edges,
                    sorted_features,
                    n_nodes)
        # BUG FIX (this is why the weights looked frozen): do NOT apply
        # softmax before CrossEntropyLoss.  The model's Tanh output is
        # already confined to a narrow range; softmax-ing it and then
        # letting CrossEntropyLoss run its own internal log_softmax yields
        # a nearly uniform distribution and vanishingly small gradients,
        # so the parameters barely move even at lr=1e-1.  Feed the raw
        # scores straight to the loss.
        #
        # Organize the labels - CE labels.  Rows of lbl_mat with no
        # positive entry stay at -100, CrossEntropyLoss's default
        # ignore_index, so they contribute no loss.
        row, col = torch.where(d.lbl_mat > 0)
        labels = torch.full((d.lbl_mat.size(0),), -100, dtype=torch.long)
        labels[row] = col
        labels = labels.to(device)
        loss = loss_fn(out, labels)
        list_loss.append(float(loss))
        print(float(loss))
        with torch.no_grad():
            # Softmax is fine here: it is metric-only, outside the
            # gradient path, and does not change the argmax anyway.
            acc = acc_func(torch.softmax(out, dim=1), labels)
            list_acc.append(acc)
        loss.backward()
        optim.step()