I am Facing issue in supervising my VAE.Usually In VAE, it is an unsupervised approach with BCE logits and reconstruction loss. We have also added BCE loss on an true_label.
My model looks something like this:
class GCNModelVAE(nn.Module):
def __init__(self, input_feat_dim, hidden_dim1, hidden_dim2,num_classes, dropout):
super(GCNModelVAE, self).__init__()
self.gc1 = GraphConvolution(input_feat_dim, hidden_dim1, dropout, act=F.relu)
self.gc2 = GraphConvolution(hidden_dim1, hidden_dim2, dropout, act=lambda x: x)
self.gc3 = GraphConvolution(hidden_dim1, hidden_dim2, dropout, act=lambda x: x)
self.dc = InnerProductDecoder(dropout, act=lambda x: x)
#classification feed forward network
self.fc5 = nn.Linear(hidden_dim2, 64)
self.fc6 = nn.Linear(64, num_classes)
self.relu1 = nn.ReLU()
self.softmax=nn.Softmax(dim=1)
self.sigmoid = nn.Sigmoid()
def encode(self, x, adj):
hidden1 = self.gc1(x, adj)
return self.gc2(hidden1, adj), self.gc3(hidden1, adj)
def reparameterize(self, mu, logvar):
if self.training:
std = torch.exp(logvar)
eps = torch.randn_like(std)
return eps.mul(std).add_(mu)
else:
return mu
def forward(self, x, adj):
mu, logvar = self.encode(x, adj)
z = self.reparameterize(mu, logvar)
return self.dc(z), mu, logvar,self.sigmoid(self.fc6(self.relu1(self.fc5(z))))
and I have defined the loss as :
label_loss = nn.BCELoss()
def loss_function(preds, labels, mu, logvar,true_label,predicted_label, n_nodes, norm, pos_weight):
cost = norm * F.binary_cross_entropy_with_logits(preds, labels, pos_weight=pos_weight)
c_loss=label_loss(predicted_label, true_label)
KLD = -0.5 / n_nodes * torch.mean(torch.sum(
1 + 2 * logvar - mu.pow(2) - logvar.exp().pow(2), 1))
return KLD + cost+ c_loss
So here cost is the reconstruction loss which is working perfectly, but I am not able to learn from c_loss which is basically a categorical loss with true_label and predicted label
for example:
true_label=[0,0,0,0,1,0]
predicted_label=[0.4466, 0.4002, 0.4182, 0.4569, 0.4530, 0.5556]
I suppose I am doing something wrong in the classification model(or in the activation functions of relu or sigmoid) or the c_loss in loss function. any help will be really helpful