I have a clustering model, for instance a simple Multi-Layer Perceptron (MLP), that produces scores for each cluster. I use these scores to assign each node to a cluster, and then feed those clusters into another model, which computes the primary loss used to train the clustering model. The challenge is that I use the argmax function to assign nodes to clusters, and since argmax is not differentiable, no gradients flow back and the clustering model never updates.
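
A minimal, standalone sketch of the core problem (hypothetical tensors, not the full model): argmax returns an integer tensor with no grad_fn, so the autograd graph is cut at that point and nothing upstream of it receives gradients.

import torch

scores = torch.rand(4, 3, requires_grad=True)          # e.g. cluster scores from an MLP
hard_assign = torch.argmax(scores, dim=1)              # integer tensor; argmax has no gradient
print(hard_assign.requires_grad, hard_assign.grad_fn)  # False None -- the graph is cut here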

Can you write a short code example that duplicates the issue? For example:

import torch
import torch.nn as nn

model = nn.Linear(15, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)  # include an optimizer since we need to track gradients
criterion = nn.MSELoss()
dummy_inputs = torch.rand((10, 15))  # make something to put into the model
outputs = model(dummy_inputs)
# etc.

import torch
import torch.nn as nn
import torch.nn.functional as F

class dummy_1(nn.Module):
    def __init__(self, input_dim, num_class):
        super(dummy_1, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, num_class),
            nn.ReLU()
        )

    def forward(self, x):
        scores = self.encoder(x)
        return F.softmax(scores, dim=1)

def Compute_Loss(scores):
    # argmax is not differentiable: the autograd graph is cut here,
    # so nothing below this line sends gradients back to model_1
    predicted_class = torch.argmax(scores, dim=1)
    clusters = {}
    # Iterate through nodes and assign them to clusters based on predicted labels
    for node, label in enumerate(predicted_class):
        label = label.item()
        if label not in clusters:
            clusters[label] = [node]
        else:
            clusters[label].append(node)
    loss = 0.
    for key in clusters.keys():
        # Dataset_pyG and dummy_2 come from the surrounding project (not shown here)
        Cluster = Dataset_pyG.cuda().subgraph(torch.tensor(clusters[key]).cuda())
        model_2 = dummy_2()  # Compute another loss
        loss += model_2(Cluster)
    return loss

model_1 = dummy_1(128, num_class=6)
optimizer = torch.optim.SGD(model_1.parameters(), lr=0.001)
for epoch in range(10):
    scores = model_1(input)  # `input` holds the node features (not shown here)
    loss = Compute_Loss(scores)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

To compute a loss, you need either pre-designated targets (supervised learning) or some sort of objective function (as in reinforcement learning). An objective function designates some quantity to maximize or minimize, such as maximizing a score or minimizing damage.

From what I can tell, you are using model_1 to make some choice, and then using model_2 to calculate a loss value based on that choice. But it seems you're lacking an objective function to train model_2, as well as a way to pass those gradients back to model_1.

Thank you.
I do have an objective function to train model_2. After training model_2, it computes and outputs a number that model_1 should be updated to maximize, and I think that because of the argmax, gradients can't be passed from model_2 back to model_1.
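
For what it's worth, the standard trick for exactly this situation (not something suggested earlier in this thread) is the straight-through Gumbel-softmax: the forward pass yields a hard one-hot assignment, while the backward pass uses the soft probabilities, so gradients still reach model_1. A minimal sketch:

import torch
import torch.nn.functional as F

logits = torch.randn(10, 6, requires_grad=True)  # e.g. raw cluster scores from model_1
# hard=True: one-hot assignments in the forward pass, soft gradients in the backward pass
assignments = F.gumbel_softmax(logits, tau=1.0, hard=True, dim=-1)
loss = assignments.sum()        # stand-in for whatever model_2 computes from the clusters
loss.backward()
print(logits.grad is not None)  # True -- gradients flow despite the hard assignments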

Remove the final softmax on model_1 and use the raw logits, which parameterize the probability distribution over classes. A cross-entropy loss (e.g. nn.CrossEntropyLoss) takes as arguments the logits for the choice of class from model_1 and the integer class number, presumably from model_2. And that will give you a loss you can use to backpropagate gradients with.
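
A minimal sketch of that suggestion, assuming the integer class numbers come from model_2 (here they are random stand-ins):

import torch
import torch.nn as nn

encoder = nn.Sequential(              # model_1 without the final softmax
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 6),                 # raw logits out
)
criterion = nn.CrossEntropyLoss()     # expects raw logits and integer class indices

x = torch.rand(10, 128)               # dummy node features
targets = torch.randint(0, 6, (10,))  # integer class numbers, e.g. derived from model_2
loss = criterion(encoder(x), targets)
loss.backward()                       # gradients now reach model_1's parameters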