Here is the loss code:
def our_loss(output, target, gram=None, eps=None):
    """Return the Gram-weighted quadratic form of the element-wise BCE.

    The per-element binary cross-entropy between ``output`` and ``target``
    is flattened into a vector ``l`` and the scalar ``l^T S l`` is returned,
    where ``S`` is the Gram matrix.

    Args:
        output: Tensor of predictions; clamped to ``[eps, 1 - eps]`` before
            the logarithms are taken so that ``log(0)`` cannot occur.
        target: Tensor of targets, combined element-wise with ``output``
            (presumably the same shape -- confirm against the caller).
        gram: Optional square matrix with one row/column per flattened loss
            element. When ``None``, the module-level ``S`` is used.
        eps: Optional clamp margin. When ``None``, the module-level
            ``_EPSILON`` is used.

    Returns:
        A zero-dimensional float32 tensor holding the loss.
    """
    # Fall back to the module-level values so existing two-argument callers
    # keep working unchanged.
    if gram is None:
        gram = S
    if eps is None:
        eps = _EPSILON

    y_pred = torch.clamp(output, eps, 1.0 - eps)
    bce = -(target * torch.log(y_pred) + (1.0 - target) * torch.log(1.0 - y_pred))

    # Cast once up front: torch.dot requires both operands to share a dtype,
    # and the matrix-vector product below is computed in float32.
    bce = torch.flatten(bce).float()
    weighted = torch.mv(gram.float(), bce)
    return torch.dot(bce, weighted)
Here, S is the Gram matrix of the input patterns, computed with an RBF kernel.