Cross Entropy loss in Supervised VAE

I am facing an issue in supervising my VAE. Usually a VAE is an unsupervised approach with a BCE-with-logits reconstruction loss (plus the KL term). I have also added a BCE loss on the true_label.
My model looks something like this:

class GCNModelVAE(nn.Module):
    def __init__(self, input_feat_dim, hidden_dim1, hidden_dim2, num_classes, dropout):
        super(GCNModelVAE, self).__init__()
        self.gc1 = GraphConvolution(input_feat_dim, hidden_dim1, dropout, act=F.relu)
        self.gc2 = GraphConvolution(hidden_dim1, hidden_dim2, dropout, act=lambda x: x)
        self.gc3 = GraphConvolution(hidden_dim1, hidden_dim2, dropout, act=lambda x: x)
        self.dc = InnerProductDecoder(dropout, act=lambda x: x)
        
        #classification feed forward network
        self.fc5 = nn.Linear(hidden_dim2, 64)
        self.fc6 = nn.Linear(64, num_classes)
        self.relu1 = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)
        self.sigmoid = nn.Sigmoid()

    def encode(self, x, adj):
        hidden1 = self.gc1(x, adj)
        return self.gc2(hidden1, adj), self.gc3(hidden1, adj)

    def reparameterize(self, mu, logvar):
        if self.training:
            std = torch.exp(logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu

    def forward(self, x, adj):
        mu, logvar = self.encode(x, adj)
        z = self.reparameterize(mu, logvar)
        return self.dc(z), mu, logvar, self.sigmoid(self.fc6(self.relu1(self.fc5(z))))

and I have defined the loss as:

label_loss = nn.BCELoss()
def loss_function(preds, labels, mu, logvar, true_label, predicted_label, n_nodes, norm, pos_weight):
    cost = norm * F.binary_cross_entropy_with_logits(preds, labels, pos_weight=pos_weight)
    c_loss = label_loss(predicted_label, true_label)
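    # note: this KL term treats logvar as log(std), matching std = torch.exp(logvar)
    # in reparameterize, so 2 * logvar = log(var) and logvar.exp().pow(2) = var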
    KLD = -0.5 / n_nodes * torch.mean(torch.sum(
        1 + 2 * logvar - mu.pow(2) - logvar.exp().pow(2), 1))
    return KLD + cost + c_loss

Here cost is the reconstruction loss, which works perfectly, but the model is not learning from c_loss, which is essentially a categorical loss between true_label and predicted_label.
for example:
true_label=[0,0,0,0,1,0]
predicted_label=[0.4466, 0.4002, 0.4182, 0.4569, 0.4530, 0.5556]
I suppose I am doing something wrong in the classification model (or in the ReLU/sigmoid activations), or in the c_loss term of the loss function. Any help would be really appreciated.

You don’t account for the density of z in the second decoding network. I think you need to use normalizing flows (and the ELBO loss) there to counteract the sampling noise. Or maybe decoding mu will work.

Thanks @googlebot. If possible, can you please suggest the changes in the code? Is the model incorrect, or is the loss definition incorrect?

I suspect that this doesn’t work, because z is too noisy.

If you use mu or mu.detach() there, that should work like a usual MLE predictor.
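
For example, a minimal sketch of that change in your forward (everything else unchanged):

    def forward(self, x, adj):
        mu, logvar = self.encode(x, adj)
        z = self.reparameterize(mu, logvar)
        # reconstruct from the sampled z, but classify from the mean; detach()
        # keeps the classification gradient from flowing back into the encoder
        return self.dc(z), mu, logvar, self.sigmoid(self.fc6(self.relu1(self.fc5(mu.detach()))))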

Thanks again @googlebot, but I also tried decoding mu in place of z, and there is no increase in performance. If possible, can you tell me how to implement the ELBO loss you mentioned above? It would be really helpful if you could show where it goes in my file.
I am trying to implement something like the Deep Generative Model in https://github.com/wohlert/semi-supervised-pytorch/tree/master/examples/notebooks (though they use the MNIST dataset there, while I am working on a graph neural network).

Isn’t that odd? I see two scenarios: 1) the encoder throws away information needed for classification; 2) your two tasks conflict, wanting different representations. Basically, you should check whether the classifier trains on its own (with c_loss as the only loss).

VAEs already use the ELBO loss; here it appears as a variant with a closed-form KL. If you transformed z’s distribution with normalizing flows, you would need another form. But that’s a totally different decoder design, requiring different layers. And I’m not sure it will help you if the classifier doesn’t even work in non-variational mode.
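
For reference, the KL term in your loss_function is already that closed-form piece of the ELBO; with the more common convention logvar = log(sigma^2) it would read:

    # closed-form KL( N(mu, sigma^2) || N(0, I) ), with the same scaling as your code
    KLD = -0.5 / n_nodes * torch.mean(torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), 1))

Your version is equivalent because your reparameterize uses std = torch.exp(logvar), i.e. it treats logvar as log(sigma).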

Note that they start from a simpler model, where the auxiliary classifier is applied to the non-encoded inputs. The second model uses a “learned encoder as a feature extractor”, but note that they pretrain and freeze the auto-encoder!

And their code is more sophisticated; for example, I see the use of importance sampling. There is also some nice illustrative code in that repo for the ELBO, flows, etc.

Yes, it seems the c_loss is not working even on its own.

class Classifier(nn.Module):
    def __init__(self, input_feat_dim, hidden_dim1, num_classes):
        """
        Single hidden layer classifier
        with softmax output.
        """
        super(Classifier, self).__init__()
        self.fc5 = nn.Linear(input_feat_dim, hidden_dim1)
        self.fc6 = nn.Linear(hidden_dim1, num_classes)
        self.relu1 = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = F.relu(self.fc5(x))
        x = F.softmax(self.fc6(x), dim=-1)
        return x

class GCNModelVAE(nn.Module):
    def __init__(self, input_feat_dim, hidden_dim1, hidden_dim2, num_classes, dropout):
        super(GCNModelVAE, self).__init__()
        self.gc1 = GraphConvolution(input_feat_dim, hidden_dim1, dropout, act=F.relu)
        self.gc2 = GraphConvolution(hidden_dim1, hidden_dim2, dropout, act=lambda x: x)
        self.gc3 = GraphConvolution(hidden_dim1, hidden_dim2, dropout, act=lambda x: x)
        self.dc = InnerProductDecoder(dropout, act=lambda x: x)
        self.classifier = Classifier(hidden_dim2, 64, num_classes)

    def encode(self, x, adj):
        hidden1 = self.gc1(x, adj)
        return self.gc2(hidden1, adj), self.gc3(hidden1, adj)

    def reparameterize(self, mu, logvar):
        if self.training:
            std = torch.exp(logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu

    def forward(self, x, adj):
        mu, logvar = self.encode(x, adj)
        z = self.reparameterize(mu, logvar)
        return self.dc(z), mu, logvar, self.classifier(mu)


def loss_function(preds, labels, mu, logvar, true_label, predicted_label, n_nodes, norm, pos_weight):
    c_loss = F.binary_cross_entropy_with_logits(predicted_label, true_label)
    return c_loss

for epoch in range(args.epochs):
        t = time.time()
        model.train()
        true_label = Onehotencoder(True_labels, num_classes)
        optimizer.zero_grad()
        recovered, mu, logvar, predicted_label = model(features, adj_norm)
        loss = loss_function(preds=recovered, labels=adj_label,
                             mu=mu, logvar=logvar, true_label=true_label, predicted_label=predicted_label,
                             n_nodes=n_nodes, norm=norm, pos_weight=pos_weight)
        
        loss.backward()
        cur_loss = loss.item()
        optimizer.step()
        _, idx = torch.max(predicted_label, dim=1)
        print('Classification accuracy is ', (idx.numpy() == True_labels).sum() / len(True_labels))
        hidden_emb = mu.data.numpy()
        roc_curr, ap_curr = get_roc_score(hidden_emb, adj_orig, val_edges, val_edges_false)

        print("Epoch:", '%04d' % (epoch + 1), 
              "train_loss=", "{:.5f}".format(cur_loss),
              "val_ap=", "{:.5f}".format(ap_curr),
              "time=", "{:.5f}".format(time.time() - t)
              )

The loss is not going down and the classification accuracy is fluctuating. If possible, can you help me with the classification above?

I missed that the loss is indeed incorrect. binary_cross_entropy_with_logits is for raw real-valued scores (and two classes); look at F.cross_entropy or F.nll_loss, and don’t apply softmax yourself, as they do it themselves.
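
Concretely, a sketch of that change in your file (assuming True_labels already holds integer class indices, which your accuracy check suggests):

    # Classifier.forward: return raw logits, no softmax
    def forward(self, x):
        x = F.relu(self.fc5(x))
        return self.fc6(x)

    # in the loss: F.cross_entropy applies log_softmax internally and expects
    # class indices rather than one-hot vectors, so the one-hot encoding can go
    true_label = torch.as_tensor(True_labels, dtype=torch.long)
    c_loss = F.cross_entropy(predicted_label, true_label)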

Thanks again @googlebot, but I have done the things you suggested. If possible, can you please take another look? I don’t know why the classification is not working as well as I expected. Maybe this (classification accuracy of 65%) is the best result I can get on this task, because it is a complex model. Thanks.

x = F.relu(self.fc6(x))

This relu is incorrect, as that’s your output layer. If that’s not enough, try using torch.selu activations (or batch normalization) and no dropout.
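
I.e. something like this sketch for the classifier forward (non-linearity only on the hidden layer, raw logits out):

    def forward(self, x):
        x = torch.selu(self.fc5(x))   # or F.relu / batch norm on the hidden layer
        return self.fc6(x)            # raw logits for F.cross_entropy, no activation here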