Thanks @ptrblck. I left the code out initially since it's quite long, sorry about that. Really appreciate your time.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import MultivariateNormal
import pandas as pd
import numpy as np
from sklearn.neighbors import NearestNeighbors
from sklearn.cluster import KMeans
import warnings
warnings.filterwarnings("ignore")
# Six 2-D Gaussian clusters with M points each
M = 50
data = torch.cat((MultivariateNormal(-7 * torch.ones(2), torch.eye(2)).sample([M]),
                  MultivariateNormal(torch.tensor([-3., 0.]), torch.eye(2)).sample([M]),
                  MultivariateNormal(torch.tensor([0., 3.]), torch.eye(2)).sample([M]),
                  MultivariateNormal(torch.tensor([-8., 10.]), torch.eye(2)).sample([M]),
                  MultivariateNormal(torch.tensor([10., -2.]), torch.eye(2)).sample([M]),
                  MultivariateNormal(7 * torch.ones(2), torch.eye(2)).sample([M])))
N, D = data.shape
# First three clusters are class 1, the remaining three class 0
labels = torch.tensor([1 if i < 3 * M else 0 for i in range(N)])
class Siamese(nn.Module):
    '''Two-layer linear net: forward_one embeds a batch of points,
    forward returns the Euclidean distance between the embeddings of two batches.'''
    def __init__(self, input_size):
        super(Siamese, self).__init__()
        self.hidden = nn.Linear(input_size, 2)
        self.out = nn.Linear(2, 2)

    def forward_one(self, x):
        x = self.hidden(x)
        x = self.out(x)
        return x

    def forward(self, x1, x2):
        out1 = self.forward_one(x1)
        out2 = self.forward_one(x2)
        dis = torch.norm(out1 - out2, dim=1)
        return dis
def SiameseLoss(outputs, labels):
    '''
    outputs: (batch,) pairwise embedding distances
    labels:  (batch,) 1 for same-class pairs, 0 for different-class pairs
    Returns the mean distance of same-class pairs minus the mean distance
    of different-class pairs.
    '''
    samepairslen = outputs[labels == 1].shape[0]
    diffpairslen = outputs[labels == 0].shape[0]
    if diffpairslen == 0:
        return torch.sum(outputs[labels == 1]) / samepairslen
    elif samepairslen == 0:
        return -torch.sum(outputs[labels == 0]) / diffpairslen
    return (torch.sum(outputs[labels == 1]) / samepairslen
            - torch.sum(outputs[labels == 0]) / diffpairslen)
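# Minimal sanity check of SiameseLoss on made-up distances (illustration only):
# same-class pairs contribute +mean(distance), different-class pairs -mean(distance),
# so minimizing the loss pulls same-class pairs together and pushes different-class pairs apart.
_d = torch.tensor([0.5, 2.0, 3.0, 4.0])   # hypothetical pairwise distances
_y = torch.tensor([1, 1, 0, 0])           # 1 = same class, 0 = different class
print(SiameseLoss(_d, _y))                # mean(0.5, 2.0) - mean(3.0, 4.0) = tensor(-2.2500)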
def nnPairIdxsGenerator(data, labels, batch_size, k):
    '''
    Draws batch_size random anchor points and pairs each anchor with its
    k nearest neighbours in input space.
    Returns:
        pairs: (batch_size * k, 2) long tensor of datapoint INDICES
        pair_labels: (batch_size * k,) 1 if both points share a class label, else 0
    '''
    N, D = data.shape
    rindxs = torch.randint(low=0, high=N, size=(batch_size,))
    _, nbrs = NearestNeighbors(n_neighbors=k + 1, algorithm='ball_tree').fit(
        data.detach()).kneighbors(data[rindxs])
    # indices must be an integer dtype so they can be used for tensor indexing later
    pairs = torch.zeros((batch_size * k, 2), dtype=torch.long)
    pair_labels = torch.zeros(batch_size * k)
    for idx in range(batch_size):
        for jdx in range(k):
            # nbrs[idx, 0] is the anchor itself (distance 0); columns 1..k are its neighbours
            pairs[idx * k + jdx, 0] = nbrs[idx, 0]
            pairs[idx * k + jdx, 1] = nbrs[idx, jdx + 1]
            pair_labels[idx * k + jdx] = (labels[nbrs[idx, 0]] == labels[nbrs[idx, jdx + 1]]).type(torch.int)
    return pairs, pair_labels
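# Quick shape check of the pair generator (illustration only): 10 anchors, 5 neighbours each
_p, _pl = nnPairIdxsGenerator(data, labels, batch_size=10, k=5)
print(_p.shape, _pl.shape)   # torch.Size([50, 2]) torch.Size([50])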
def betaln(a, b):
    # log Beta function: log B(a, b) = lgamma(a) + lgamma(b) - lgamma(a + b)
    return torch.lgamma(a) + torch.lgamma(b) - torch.lgamma(a + b)
def estimate_parameters(data, assignments):
    # Weighted counts, means and diagonal covariances per component,
    # given soft assignments of shape (N, T)
    assignment_counts = torch.sum(assignments, dim=0) + 1e-15
    means = torch.mm(assignments.T, data) / assignment_counts[:, None]
    sq_means = 2 * means * torch.mm(assignments.T, data) / assignment_counts[:, None]
    avg_cov = torch.mm(assignments.T, data * data) / assignment_counts[:, None]
    covariances = avg_cov - sq_means + means ** 2 + 1e-7
    return assignment_counts, means, covariances
def estimate_weights(assignment_counts):
    # assignment_counts shape = (T,)
    # For each component t, sum the counts of all components after t
    # (needed for the stick-breaking Beta posteriors)
    stack_cumsum = torch.cat((torch.flip(torch.cumsum(torch.flip(assignment_counts, dims=(0,)), dim=0), dims=(0,))[1:],
                              torch.zeros(1)))
    weight_concentration = (1. + assignment_counts, alpha + stack_cumsum)
    return weight_concentration
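# Worked illustration of the suffix sum used above (made-up counts):
_c = torch.tensor([3., 2., 1.])
print(torch.cat((torch.flip(torch.cumsum(torch.flip(_c, dims=(0,)), dim=0), dims=(0,))[1:], torch.zeros(1))))
# tensor([3., 1., 0.])  ->  for each component, the total count of the components after it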
def estimate_means(data, assignment_counts, means):
    # Posterior mean precision and posterior means
    # (prior mean = overall data mean, prior precision = 1)
    new_mean_prec = 1. + assignment_counts
    means_unnorm = (1. * torch.mean(data, dim=0) +
                    assignment_counts[:, None] * means)
    new_means = means_unnorm / new_mean_prec[:, None]
    return new_means, new_mean_prec
def estimate_precisions(data, assignment_counts, means, covariances):
    # Posterior degrees of freedom, diagonal covariances and precisions
    # (uses the globals D, gamma and mean_prec)
    new_dof = D + assignment_counts
    diff = means - torch.mean(data, dim=0)
    new_covariances = (gamma * torch.var(data, dim=0) +
                       assignment_counts[:, None] *
                       (covariances + (1. / mean_prec)[:, None] * torch.mul(diff, diff)))
    new_covariances /= new_dof[:, None]
    new_precisions = 1. / torch.sqrt(new_covariances)
    return new_covariances, new_precisions, new_dof
def initialization(data):
    # Hard-assign points with k-means, then do one variational M-step to get
    # initial posterior parameters; mean_prec and dof are kept as globals
    global mean_prec
    global dof
    label_matrix = torch.zeros(N, T)
    assignments = KMeans(n_clusters=T, n_init=1).fit(data.detach().numpy()).labels_
    label_matrix[torch.arange(N), assignments] = 1
    counts, means, covariances = estimate_parameters(data, label_matrix)
    weight_conc = estimate_weights(counts)
    new_means, mean_prec = estimate_means(data, counts, means)
    new_covariances, new_precisions, dof = estimate_precisions(
        data, counts, means, covariances)
    return weight_conc, new_means, new_covariances, new_precisions
def compute_expectations(data, means, precisions_pd, concentrations):
    def log_determinant(x): return torch.sum(torch.log(x), dim=1)
    # Estimate log Gaussian probability (requires means, Cholesky precision
    # matrix, data, dof)
    precisions = precisions_pd ** 2
    sum_means_sq = torch.sum(means ** 2 * precisions, dim=1)
    log_prob = (sum_means_sq -
                2. * torch.mm(data, (means * precisions).T) +
                torch.mm(data ** 2, precisions.T))
    log_prob_g = (-.5 * (D * np.log(2 * np.pi) + log_prob) +
                  log_determinant(precisions_pd)) - (.5 * D * torch.log(dof))
    # Estimate total log probability (requires data, dof)
    new_dofD = dof - torch.arange(0, D)[:, None]
    log_prob_l = D * np.log(2.) + torch.sum(torch.digamma(.5 * new_dofD), dim=0)
    log_prob_total = log_prob_g + 0.5 * (log_prob_l - D / mean_prec)
    # Estimate expected log stick-breaking weights:
    # E[log v_t] plus the cumulative sum of E[log(1 - v_j)] over previous components
    digamma_sum = torch.digamma(concentrations[0] + concentrations[1])
    log_weights = torch.digamma(concentrations[0]) - digamma_sum + torch.cat(
        (torch.zeros(1), torch.cumsum(torch.digamma(concentrations[1]) - digamma_sum, dim=0)[:-1]))
    # Estimate log responsibilities (normalized over components)
    log_weighted_prob = log_weights + log_prob_total
    softmax_log_weighted_prob = torch.logsumexp(log_weighted_prob, dim=1)
    log_likelihoods = log_weighted_prob - softmax_log_weighted_prob[:, None]
    mean_log_weighted_prob = torch.mean(softmax_log_weighted_prob)
    return mean_log_weighted_prob, log_likelihoods  # log_likelihoods has shape N x T
def max_probabilities(data, log_likelihoods):
    # Variational M-step from the responsibilities; also refreshes the
    # global mean_prec and dof used by compute_expectations and elbo
    global mean_prec
    global dof
    counts, means, covariances = estimate_parameters(data, torch.exp(log_likelihoods))
    weight_conc = estimate_weights(counts)
    new_means, mean_prec = estimate_means(data, counts, means)
    new_covariances, new_precisions, dof = estimate_precisions(
        data, counts, means, covariances)
    return weight_conc, new_means, new_covariances, new_precisions
def elbo(log_likelihoods, precisions, weight_concentrations):
    log_determinant = torch.sum(torch.log(precisions), dim=1) - 0.5 * D * torch.log(dof)
    log_norm = -(dof * log_determinant +
                 dof * D * .5 * np.log(2.) +
                 torch.sum(torch.lgamma(.5 * (dof - torch.arange(D)[:, None])), dim=0))
    log_norm = torch.sum(log_norm)
    log_norm_weight = -torch.sum(betaln(weight_concentrations[0], weight_concentrations[1]))
    return (-torch.sum(torch.exp(log_likelihoods) * log_likelihoods) -
            log_norm - log_norm_weight - 0.5 * D * torch.sum(torch.log(mean_prec)))
N, D = data.shape
# Hyperparameters
T = 4              # truncation level / number of mixture components
n_iter = 500
batch_size = 10
mean_prec = 1.     # prior precision on the component means
dof = D            # prior degrees of freedom
gamma = 1.5        # prior scale on the covariances
alpha = 1.5        # stick-breaking (DP) concentration
net = Siamese(D)
net_optim = torch.optim.Adam(net.parameters(), lr=0.05, weight_decay=1)
concentrations, means, covariances, precisions = initialization(net.forward_one(data))
for i in range(n_iter):
    net_optim.zero_grad()
    # pairs, pair_labels = pairGenerator(data, labels, batch_size)
    pairs, pair_labels = nnPairIdxsGenerator(data, labels, batch_size, 5)
    embedding = net.forward_one(data)
    outputs = net(embedding[pairs[:, 0]], embedding[pairs[:, 1]])
    # Variational E-step and M-step on the current embedding
    mean_lwp, log_likelihoods = compute_expectations(
        embedding, means, precisions, concentrations)
    concentrations, means, covariances, precisions = max_probabilities(
        embedding, log_likelihoods)
    dpmm_loss = elbo(log_likelihoods, precisions, concentrations)
    net_loss = SiameseLoss(outputs, pair_labels)
    net_loss.backward(retain_graph=True)
    dpmm_loss.backward(retain_graph=True)
    net_optim.step()

with torch.no_grad():
    _, log_likelihoods = compute_expectations(net.forward_one(data), means, precisions, concentrations)
    assignments = log_likelihoods.argmax(dim=1)
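For completeness, here is a quick visual sanity check I'd tack on at the end (assuming matplotlib is available; purely illustrative, not part of the model):

import matplotlib.pyplot as plt

with torch.no_grad():
    emb = net.forward_one(data)
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.scatter(data[:, 0], data[:, 1], c=assignments, s=10)
plt.title('inferred clusters (input space)')
plt.subplot(1, 2, 2)
plt.scatter(emb[:, 0], emb[:, 1], c=assignments, s=10)
plt.title('inferred clusters (embedding space)')
plt.show()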