Hi everyone. I am using the variational recurrent model below to learn the dynamics of an environment, P(z_t | a_{<t}, z_{<t}). Even though I initialize the layers, after some iterations the weights and biases of the decoder become NaN.
import torch
import torch.nn as nn
import torch.distributions as tdist

class VRNN(nn.Module):
@staticmethod
def weight_init(m):
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Embedding):
nn.init.xavier_uniform_(m.weight)
elif isinstance(m,nn.GRU) or isinstance(m,nn.LSTM):
for name, param in m.named_parameters():
if "weight_ih" in name:
torch.nn.init.xavier_uniform_(param.data)
elif "weight_hh" in name:
torch.nn.init.orthogonal_(param.data)
elif "bias" in name:
param.data.fill_(0)
def __init__(self, u_dim, y_dim, h_dim, z_dim, n_layers, n_mixtures, device, bias=False):
super(VRNN, self).__init__()
self.y_dim = y_dim
self.u_dim = u_dim
self.h_dim = h_dim
self.z_dim = z_dim
self.n_layers = n_layers
self.n_mixtures = n_mixtures
self.device = device
# feature-extracting transformations (phi_y, phi_u and phi_z)
self.phi_y = nn.Sequential(
nn.Linear(self.y_dim, self.h_dim),
nn.ReLU(),
nn.Linear(self.h_dim, self.h_dim))
self.phi_u = nn.Sequential(
nn.Linear(self.u_dim, self.h_dim),
nn.ReLU(),
nn.Linear(self.h_dim, self.h_dim))
self.phi_z = nn.Sequential(
nn.Linear(self.z_dim, self.h_dim),
nn.ReLU(),
nn.Linear(self.h_dim, self.h_dim))
# encoder function (phi_enc) -> Inference
self.enc = nn.Sequential(
nn.Linear(self.h_dim + self.h_dim, self.h_dim),
nn.ReLU(),
nn.Linear(self.h_dim, self.h_dim),
nn.ReLU(), )
self.enc_mean = nn.Sequential(
nn.Linear(self.h_dim, self.z_dim))
self.enc_logvar = nn.Sequential(
nn.Linear(self.h_dim, self.z_dim),
nn.Softplus(), )
# prior function (phi_prior) -> Prior
self.prior = nn.Sequential(
nn.Linear(self.h_dim, self.h_dim),
nn.ReLU(),
nn.Linear(self.h_dim, self.h_dim),
nn.ReLU(), )
self.prior_mean = nn.Sequential(
nn.Linear(self.h_dim, self.z_dim))
self.prior_logvar = nn.Sequential(
nn.Linear(self.h_dim, self.z_dim),
nn.Softplus(), )
# decoder function (phi_dec) -> Generation
self.dec = nn.Sequential(
nn.Linear(self.h_dim + self.h_dim, self.h_dim),
nn.ReLU(),
nn.Linear(self.h_dim, self.h_dim),
nn.ReLU(), )
self.dec_mean = nn.Sequential(
nn.Linear(self.h_dim, self.y_dim * self.n_mixtures), )
self.dec_logvar = nn.Sequential(
nn.Linear(self.h_dim, self.y_dim * self.n_mixtures),
nn.Softplus(), )
self.dec_pi = nn.Sequential(
nn.Linear(self.h_dim, self.y_dim * self.n_mixtures),
nn.Softmax(dim=1)
)
# recurrence function (f_theta) -> Recurrence
self.rnn = nn.GRU(self.h_dim + self.h_dim, self.h_dim, self.n_layers, bias)
self.apply(self.weight_init)
def forward(self, u, y):
batch_size = y.size(0)
seq_len = y.shape[-1]
# allocation
loss = 0
# initialization
h = torch.zeros(self.n_layers, batch_size, self.h_dim, device=self.device)
# for all time steps
for t in range(seq_len):
# feature extraction: y_t
phi_y_t = self.phi_y(y[:, :, t])
# feature extraction: u_t
phi_u_t = self.phi_u(u[:, :, t])
# encoder: y_t, h_t -> z_t
enc_t = self.enc(torch.cat([phi_y_t, h[-1]], 1))
enc_mean_t = self.enc_mean(enc_t)
enc_logvar_t = self.enc_logvar(enc_t)
# prior: h_t -> z_t (for KLD loss)
prior_t = self.prior(h[-1])
prior_mean_t = self.prior_mean(prior_t)
prior_logvar_t = self.prior_logvar(prior_t)
# sampling and reparameterization: get a new z_t
z_t=self.reparametrization(enc_mean_t, enc_logvar_t)
# feature extraction: z_t
phi_z_t = self.phi_z(z_t)
# decoder: h_t, z_t -> y_t
dec_t = self.dec(torch.cat([phi_z_t, h[-1]], 1))
dec_mean_t = self.dec_mean(dec_t).view(batch_size, self.y_dim, self.n_mixtures)
dec_logvar_t = self.dec_logvar(dec_t).view(batch_size, self.y_dim, self.n_mixtures)
dec_pi_t = self.dec_pi(dec_t).view(batch_size, self.y_dim, self.n_mixtures)
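# debug: dump the decoder's Linear weights and biases every step to see when they first become NaN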
for layer in self.dec_mean.children():
if isinstance(layer, nn.Linear):
print('mean weight:', layer.weight)
print('mean bias:', layer.bias)
for layer in self.dec_logvar.children():
if isinstance(layer, nn.Linear):
print('log_var weight:', layer.weight)
print('log_var bias:', layer.bias)
# recurrence: u_t+1, z_t -> h_t+1
_, h = self.rnn(torch.cat([phi_u_t, phi_z_t], 1).unsqueeze(0), h)
# computing the loss
KLD = self.kld_gauss(enc_mean_t, enc_logvar_t, prior_mean_t, prior_logvar_t)
loss_pred = self.loglikelihood_gmm(y[:, :, t], dec_mean_t, dec_logvar_t, dec_pi_t)
loss += - loss_pred + KLD
return loss
def reparametrization(self, mu, log_var):
# Reparameterization trick
std = torch.exp(log_var / 2.)
eps = torch.randn_like(std)
return eps.mul(std).add_(mu)
def generate(self, u):
# get the batch size
batch_size = u.shape[0]
# length of the sequence to generate
seq_len = u.shape[-1]
# allocation
sample = torch.zeros(batch_size, self.y_dim, seq_len, device=self.device)
sample_mu = torch.zeros(batch_size, self.y_dim, seq_len, device=self.device)
sample_sigma = torch.zeros(batch_size, self.y_dim, seq_len, device=self.device)
h = torch.zeros(self.n_layers, batch_size, self.h_dim, device=self.device)
# for all time steps
for t in range(seq_len):
# feature extraction: u_t+1
phi_u_t = self.phi_u(u[:, :, t])
# prior: h_t -> z_t
prior_t = self.prior(h[-1])
prior_mean_t = self.prior_mean(prior_t)
prior_logvar_t = self.prior_logvar(prior_t)
# sampling and reparameterization: get new z_t
z_t = self.reparametrization(prior_mean_t, prior_logvar_t)
# feature extraction: z_t
phi_z_t = self.phi_z(z_t)
# decoder: z_t, h_t -> y_t
dec_t = self.dec(torch.cat([phi_z_t, h[-1]], 1))
dec_mean_t = self.dec_mean(dec_t).view(batch_size, self.y_dim, self.n_mixtures)
dec_logvar_t = self.dec_logvar(dec_t).view(batch_size, self.y_dim, self.n_mixtures)
dec_pi_t = self.dec_pi(dec_t).view(batch_size, self.y_dim, self.n_mixtures)
# store the samples
sample[:, :, t], sample_mu[:, :, t], sample_sigma[:, :, t] = self._reparameterized_sample_gmm(dec_mean_t,
dec_logvar_t,
dec_pi_t)
# recurrence: u_t+1, z_t -> h_t+1
_, h = self.rnn(torch.cat([phi_u_t, phi_z_t], 1).unsqueeze(0), h)
return sample, sample_mu, sample_sigma
def _reparameterized_sample_gmm(self, mu, logvar, pi):
# select the mixture indices
alpha = torch.distributions.Categorical(pi).sample()
# build flat indices that pick the sampled mixture component for each (batch, channel)
n_mix = logvar.shape[-1]
raveled_index = torch.arange(len(alpha.flatten()), device=self.device) * n_mix + alpha.flatten()
logvar_sel = logvar.flatten()[raveled_index]
mu_sel = mu.flatten()[raveled_index]
# get correct dimensions
logvar_sel = logvar_sel.view(logvar.shape[:-1])
mu_sel = mu_sel.view(mu.shape[:-1])
# resample
#temp = tdist.Normal(mu_sel, logvar_sel.exp().sqrt())
#sample = tdist.Normal.rsample(temp)
sample = self.reparametrization(mu_sel, logvar_sel)
return sample, mu_sel, logvar_sel.exp().sqrt()
def loglikelihood_gmm(self, x, mu, logvar, pi):
# init
loglike = 0
# for all data channels
for n in range(x.shape[1]):
# likelihood of a single mixture at evaluation point
assert not torch.isnan(mu[:, n, :]).any()
assert not torch.isnan(logvar[:, n, :]).any()
pred_dist = tdist.Normal(mu[:, n, :], logvar[:, n, :].exp().sqrt())
x_mod = torch.mm(x[:, n].unsqueeze(1), torch.ones(1, self.n_mixtures, device=self.device))
like = pred_dist.log_prob(x_mod)
# weighting by probability of mixture and summing
temp = (pi[:, n, :] * like)
temp = temp.sum()
# log-likelihood added to previous log-likelihoods
loglike = loglike + temp
return loglike
@staticmethod
def kld_gauss(mu_q, logvar_q, mu_p, logvar_p):
# Goal: Minimize KL divergence between q_pi(z|xi) || p(z|xi)
# This is equivalent to maximizing the ELBO: - D_KL(q_phi(z|xi) || p(z)) + Reconstruction term
# This is equivalent to minimizing D_KL(q_phi(z|xi) || p(z))
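# closed form for two diagonal Gaussians:
# D_KL(q || p) = 0.5 * sum( logvar_p - logvar_q - 1 + (exp(logvar_q) + (mu_q - mu_p)^2) / exp(logvar_p) )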
term1 = logvar_p - logvar_q - 1
term2 = (torch.exp(logvar_q) + (mu_q - mu_p) ** 2) / torch.exp(logvar_p)
kld = 0.5 * torch.sum(term1 + term2)
return kld
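For reference, I drive the model with essentially the minimal loop below. The layer sizes, batch size, sequence length and random data are placeholders (only the 1e-3 learning rate matches the log further down), so treat it as a sketch rather than my exact script.

# minimal driver sketch; all dimensions and the random data are placeholders
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VRNN(u_dim=4, y_dim=2, h_dim=64, z_dim=8, n_layers=1,
             n_mixtures=5, device=device).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(5):
    # u: (batch, u_dim, seq_len), y: (batch, y_dim, seq_len)
    u = torch.randn(16, 4, 50, device=device)
    y = torch.randn(16, 2, 50, device=device)
    optimizer.zero_grad()
    loss = model(u, y)
    loss.backward()
    optimizer.step()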
Here are the weight and bias values of the decoder network, printed by the debug loop in forward:
log_var weight: Parameter containing:
tensor([[-0.0851, -0.1546, -0.1078, ..., 0.1609, -0.1874, -0.1871],
[ 0.1567, 0.0967, -0.1393, ..., 0.0806, 0.1113, -0.1480],
[-0.0308, 0.0223, -0.0731, ..., -0.1336, 0.1145, 0.0562],
...,
[ 0.0786, 0.0827, -0.0301, ..., -0.0144, -0.1381, -0.0895],
[ 0.0378, -0.1866, -0.1844, ..., -0.1335, 0.0971, -0.1709],
[ 0.0795, 0.0443, -0.1000, ..., -0.1163, 0.1738, -0.0816]],
device='cuda:0', requires_grad=True)
log_var bias: Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
device='cuda:0', requires_grad=True)
Train Epoch: [ 0/ 500], Batch [ 1/ 15 ( 7%)] Learning rate: 1.00e-03 Loss: nan
mean weight: Parameter containing:
tensor([[ 0.1852, -0.0739, -0.0665, ..., -0.1058, -0.0030, 0.0329],
[-0.1139, -0.0922, 0.0124, ..., 0.1932, -0.1466, -0.0540],
[ 0.0463, -0.1701, 0.1041, ..., -0.0501, -0.1020, 0.0691],
...,
[ 0.1905, 0.0332, 0.0530, ..., -0.0031, 0.0217, -0.1188],
[-0.1303, -0.0782, -0.0470, ..., -0.1047, -0.0840, -0.1357],
[-0.0184, -0.0348, 0.1574, ..., -0.1103, 0.1969, 0.1092]],
device='cuda:0', requires_grad=True)
mean bias: Parameter containing:
tensor([ 0.0100, 0.0100, 0.0100, 0.0100, -0.0100, 0.0100, -0.0100, 0.0100,
0.0100, 0.0100, 0.0100, 0.0100, 0.0100, 0.0100, -0.0100, 0.0100,
0.0100, -0.0100, -0.0099, 0.0100, 0.0100, -0.0100, -0.0100, -0.0100,
...,
-0.0100, 0.0100, 0.0100, 0.0100, -0.0100, 0.0099, 0.0100, -0.0100,
0.0100, -0.0100, -0.0100, -0.0100, -0.0100, -0.0100, -0.0100, -0.0100,
-0.0100, -0.0100, -0.0100, -0.0100, 0.0100, -0.0100, 0.0097, 0.0100,
0.0100, 0.0100, 0.0100, -0.0100, 0.0100, 0.0100], device='cuda:0',
requires_grad=True)
log_var weight: Parameter containing:
tensor([[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
[ nan, nan, nan, ..., nan, nan, nan],
...,
[ nan, nan, nan, ..., nan, nan, nan],
[ 0.0478, -0.1766, -0.1744, ..., -0.1235, 0.0907, -0.1609],
[ nan, nan, nan, ..., nan, nan, nan]],
device='cuda:0', requires_grad=True)
log_var bias: Parameter containing:
tensor([ nan, nan, nan, nan, nan, nan, nan, 0.0100,
nan, nan, nan, nan, nan, nan, 0.0100, nan,
...,
nan, nan, nan, nan, nan, nan, 0.0100, nan,
nan, nan, nan, nan, 0.0100, nan], device='cuda:0',
requires_grad=True)
The error message is:
334 for n in range(x.shape[1]):
335 # likelihood of a single mixture at evaluation point
--> 336 assert not torch.isnan(mu[:, n, :]).any()
337 assert not torch.isnan(logvar[:, n, :]).any()
338 pred_dist = tdist.Normal(mu[:, n, :], logvar[:, n, :].exp().sqrt())
AssertionError:
Any suggestions on how to avoid getting NaN values while training this model?
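In case it helps narrow things down, this is a minimal sketch of what I could wrap one training step in to find where the first NaN appears. It reuses model, optimizer, u and y from the driver sketch above, and the clipping value is just a guess, not a known fix:

# sketch: locate the first NaN and cap the step size (clip value is a guess)
torch.autograd.set_detect_anomaly(True)  # makes backward report the op that produced a NaN gradient

optimizer.zero_grad()
loss = model(u, y)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)  # limit the gradient norm before the update
optimizer.step()

# check whether any parameter has already turned NaN after the update
for name, p in model.named_parameters():
    if torch.isnan(p).any():
        print("NaN parameter:", name)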