Use output of clustering model


I have a clustering model, for instance a simple Multi-Layer Perceptron (MLP), that produces a score for each cluster. I use these scores to assign each node to a cluster, and then pass the clusters as input to another model, which computes the primary loss used to train the clustering model. The problem is that I assign nodes to clusters with argmax, which is not differentiable, so no gradients flow back and the clustering model is never updated.

What can I do to address this issue?

Welcome to the forums!

Can you write a short code example that duplicates the issue? For example:

import torch
import torch.nn as nn

model = nn.Linear(15, 1)

optimizer = torch.optim.SGD(model.parameters(), lr=0.001)  # include an optimizer since we need to track gradients

criterion = nn.MSELoss()

dummy_inputs = torch.rand((10, 15))  # make something to put into the model

outputs = model(dummy_inputs)

# etc.

Thank you.

import torch
import torch.nn as nn
import torch.nn.functional as F

class dummy_1(nn.Module):
    def __init__(self, input_dim, num_class):
        super(dummy_1, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.Linear(64, num_class),
        )

    def forward(self, x):
        scores = self.encoder(x)
        return F.softmax(scores, dim=1)

optimizer = torch.optim.SGD(model_1.parameters(), lr=0.001)  # model_1 is an instance of dummy_1

def Compute_Loss(scores):
    predicted_class = torch.argmax(scores, dim=1)  # integer labels; gradients stop here
    clusters = {}

    # Iterate through nodes and assign them to clusters based on predicted labels
    for node, label in enumerate(predicted_class.tolist()):
        clusters.setdefault(label, []).append(node)

    for key in clusters.keys():
        model_2 = dummy_2()  # Compute another loss (dummy_2 and loss are defined elsewhere)
    return loss

for epoch in range(10):
    ...  # training step not included in the post

To compute a loss, you need either pre-designated targets (supervised learning) or some sort of objective function (reinforcement learning). An objective function should define a quantity to maximize or minimize, such as maximizing a score or minimizing damage.

From what I can tell, you are attempting to use model_1 to make some choice. Then you are using model_2 to calculate a loss value based on that choice. But it seems you’re lacking an objective function to train model_2 and some way to additionally pass those gradients to model_1.

Here is a tutorial with an actor/critic architecture that uses PPO: Reinforcement Learning (PPO) with TorchRL Tutorial — PyTorch Tutorials 2.1.0+cu121 documentation

By the way, I was unable to get your code to run as there are aspects of your code defined outside of it.


Thank you.
I have an objective function to train model_2. After training, model_2 outputs a number, and model_1 should be updated to maximize that number. I think that, because of argmax, gradients can’t be passed from model_2 back to model_1.
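A minimal sketch of the symptom described here, assuming a hypothetical linear layer in place of model_1 and made-up sizes (8 nodes, 5 features, 3 clusters). It only illustrates that the output of argmax carries no autograd history, while an objective built directly on the softmax scores does reach the model's weights. The squared-sum objective is a placeholder, not the poster's actual loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-in for model_1: 8 nodes, 5 features, 3 clusters.
model_1 = nn.Linear(5, 3)
x = torch.rand(8, 5)
scores = F.softmax(model_1(x), dim=1)

# Hard assignment: argmax returns integer indices with no autograd history.
hard = torch.argmax(scores, dim=1)
hard_requires_grad = hard.requires_grad
hard_grad_fn = hard.grad_fn

# Soft assignment: the probabilities themselves stay in the graph.
soft_objective = (scores ** 2).sum()  # placeholder differentiable objective
soft_objective.backward()
weight_grad = model_1.weight.grad
```

Anything computed only from `hard` cannot send gradients back to model_1, which matches the behaviour reported in the thread.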

What kind of number is coming from model_2? Is it the class prediction?

If that’s the case, then you just need to make a second loss function with BCEWithLogitsLoss — PyTorch 2.1 documentation

Remove the final softmax on model_1 and use the raw logits, which are unnormalized scores that the loss function normalizes internally. Note that BCEWithLogitsLoss expects float targets with the same shape as the logits (for example, one-hot vectors); if the targets are integer class numbers, presumably from model_2, nn.CrossEntropyLoss is the matching loss. Either way, that gives you a loss you can backpropagate gradients from.
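As a rough sketch of what a loss on raw logits looks like, assuming random logits in place of model_1's output and made-up integer labels: nn.CrossEntropyLoss takes integer class indices, while nn.BCEWithLogitsLoss takes float targets with the same shape as the logits. The sizes and the `targets` tensor are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

num_nodes, num_class = 6, 3
# Stand-in for model_1's raw (pre-softmax) output.
logits = torch.randn(num_nodes, num_class, requires_grad=True)
# Hypothetical integer class labels, one per node.
targets = torch.randint(0, num_class, (num_nodes,))

# CrossEntropyLoss: logits of shape (N, C), integer class indices of shape (N,).
ce_loss = nn.CrossEntropyLoss()(logits, targets)

# BCEWithLogitsLoss: float targets with the same shape as the logits.
one_hot_targets = F.one_hot(targets, num_classes=num_class).float()
bce_loss = nn.BCEWithLogitsLoss()(logits, one_hot_targets)

(ce_loss + bce_loss).backward()
grad_shape = tuple(logits.grad.shape)
```

Both losses are differentiable with respect to the logits, so gradients reach whatever produced them.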

Well, the second model is an autoencoder that returns matmul(z, z.T), so I don’t think the approach you proposed will work.

So if model_2 is NOT predicting the class, is model_2 predicting what the weights of model_1 should be? What exactly is the output of model_2?

I think this makes it clear.

class model_2(nn.Module):
    def __init__(self):
        super().__init__()  # required before registering submodules
        self.encoder = Encoder()

    def forward(self, X, adj_):
        z = self.encoder(X, adj_)
        return torch.matmul(z, z.T).sum()

Where is Encoder() defined?

What is X? What is adj_?

Where is Encoder() defined?

Is it important? Assume it’s just an MLP.

What is X? What is adj_?

Assume both X and adj_ are the clusters you’ve seen in the first code I provided.
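One possible way around the argmax problem, which nobody in the thread proposed and which is offered here only as an assumption, is a straight-through estimator: use the hard one-hot assignment in the forward pass but let gradients flow through the softmax probabilities. In the sketch below, the two linear layers are hypothetical stand-ins for model_1 and for model_2's Encoder, and the one-hot cluster assignment is fed to the encoder directly, which simplifies the X and adj_ inputs described above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

num_nodes, input_dim, num_class = 8, 5, 3
model_1 = nn.Linear(input_dim, num_class)  # stand-in for the clustering MLP
encoder = nn.Linear(num_class, 4)          # stand-in for model_2's Encoder
x = torch.rand(num_nodes, input_dim)

probs = F.softmax(model_1(x), dim=1)
hard = F.one_hot(probs.argmax(dim=1), num_classes=num_class).float()

# Straight-through estimator: the forward value matches the hard one-hot
# assignment (up to floating-point rounding), while the backward pass
# uses the gradient of probs.
assignment = hard + probs - probs.detach()

z = encoder(assignment)
value = torch.matmul(z, z.T).sum()  # same form as model_2's output
(-value).backward()                 # minimize the negative to maximize value

grad_reaches_model_1 = model_1.weight.grad is not None
```

PyTorch also ships `torch.nn.functional.gumbel_softmax` with a `hard=True` option, which applies the same straight-through idea with added Gumbel noise. Whether either estimator trains well for this clustering objective would need to be checked empirically.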