I created a neural network in PyTorch. My loss function is a weighted negative log-likelihood. The weights are computed from the output of the network itself, but they must be treated as fixed constants: the backward pass should only flow through the log-probability term, not through the weights.
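To make the requirement concrete, the behaviour I am after is roughly this toy sketch (made-up tensors, not my actual model), where the weights are computed from the output but excluded from autograd with detach():

import torch

log_prob = torch.randn(5, requires_grad=True)  # stand-in for the network's log-probability output
weights = torch.exp(-log_prob).detach()        # computed from the output, but treated as a constant
loss = -(weights * log_prob).mean()            # weighted NLL: gradient flows only through log_prob
loss.backward()                                # the weights contribute no gradient

Here is my actual code: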
import torch
import torch.nn as nn
def extended_cumsum(x):
    # Cumulative sum along dim 0, with a leading zero slice prepended.
    return torch.cat([torch.zeros(x.size()[1:3]).unsqueeze(0), torch.cumsum(x, 0)], 0)
def is_positive(x):
    # 1 for positive entries, 0 otherwise.
    x_sign = torch.sign(x)
    return torch.where(x_sign < 0, torch.zeros_like(x_sign), x_sign)
def coupling_transform(x1, x2, nn_output, k):
    nn_output_normalized = torch.softmax(nn_output, 0)
    bins_weights = torch.ones_like(nn_output_normalized) / k
    knots_xs = extended_cumsum(bins_weights)
    prev_bins = is_positive(x2.unsqueeze(0) - knots_xs)
    current_bins = prev_bins[:-1, :, :] - prev_bins[1:, :, :]
    q_sum = torch.sum(prev_bins[1:, :, :] * nn_output_normalized, 0)
    q_current = torch.sum(current_bins * nn_output_normalized, 0)
    w_sum = torch.sum(prev_bins[1:, :, :] * bins_weights, 0)
    c_values = q_sum + k * (x2 - w_sum) * q_current
    log_det = torch.log(torch.prod(k * q_current, 1))
    return x1, c_values, log_det
def flipping_dims(n, d1, d2):
    dims = []
    for i in range(n):
        if i % 2 == 0:
            dims.append(d1)
        else:
            dims.append(d2)
    return dims
class Linear(nn.Module):
    def __init__(self, d1, d2, k, hidden):
        super().__init__()
        self.d1, self.d2, self.k = d1, d2, k
        self.net = nn.Sequential(nn.Linear(d1, hidden),
                                 nn.ReLU(),
                                 nn.Linear(hidden, hidden),
                                 nn.ReLU(),
                                 nn.Linear(hidden, hidden),
                                 nn.ReLU(),
                                 nn.Linear(hidden, k * d2))

    def forward(self, x, log_prob, flip=False):
        x1, x2 = x[:, :self.d1], x[:, self.d1:]
        if flip:
            x1, x2 = x2, x1
        z1, z2, log_det = coupling_transform(x1, x2, torch.reshape(self.net(x1), (self.k, -1, self.d2)), self.k)
        if flip:
            z1, z2 = z2, z1
        z = torch.cat([z1, z2], 1)
        return z, log_prob - log_det
class stacked_Linear(nn.Module):
    def __init__(self, d1, d2, k, hidden, n):
        super().__init__()
        self.layers = nn.ModuleList([
            Linear(_d1, _d2, k, hidden=hidden)
            for _, _d1, _d2 in zip(range(n), flipping_dims(n, d1, d2), flipping_dims(n, d1, d2)[::-1])
        ])
        self.flips = [True if i % 2 else False for i in range(n)]

    def forward(self, x, log_prev_prob):
        for layer, f in zip(self.layers, self.flips):
            x, log_prob = layer(x, log_prev_prob, flip=f)
            log_prev_prob = log_prob
        return x, log_prob
def f(x):
    # Product of exp(x_i) over the feature dimension.
    return torch.prod(torch.exp(x), 1)

def my_loss(weights, log_prob):
    # Weighted negative log-likelihood; the weights are meant to act as constants.
    loss = -torch.mean(weights * log_prob)
    return loss
model = stacked_Linear(3, 3, 32, 16, 4)
optim = torch.optim.Adam(model.parameters(), lr=1e-4)

losses = []
for _ in range(100):
    x = torch.rand(1000, 6)
    optim.zero_grad()
    z, log_prob = model(x, torch.zeros(1000))
    f_values = f(z)
    weights = f_values / torch.exp(log_prob)  # these weights should stay fixed (no gradient through them)
    loss = my_loss(weights, log_prob)
    losses.append(loss)
    loss.backward()
But the loss does not decrease, and it does not change at all even when I keep x fixed:
losses = []
x = torch.rand(1000, 6)
for _ in range(100):
    optim.zero_grad()
    z, log_prob = model(x, torch.zeros(1000))
    f_values = f(z)
    weights = f_values / torch.exp(log_prob)
    loss = my_loss(weights, log_prob)
    losses.append(loss)
    loss.backward()
[tensor(-0.1160, grad_fn=<NegBackward>),
tensor(-0.1160, grad_fn=<NegBackward>),
tensor(-0.1160, grad_fn=<NegBackward>),
tensor(-0.1160, grad_fn=<NegBackward>),
tensor(-0.1160, grad_fn=<NegBackward>), ...]