class NICE(pl.LightningModule):
def __init__(self, in_features=784, hidden_features=1000, num_coupling=4):
super().__init__()
self.save_hyperparameters()
self.layers = nn.ModuleList()
self.num_coupling = num_coupling
for _ in range(self.num_coupling):
self.layers.append(nn.Sequential(
nn.Linear(in_features // 2, hidden_features), nn.ReLU(),
nn.Linear(hidden_features, hidden_features), nn.ReLU(),
nn.Linear(hidden_features, hidden_features), nn.ReLU(),
nn.Linear(hidden_features, hidden_features), nn.ReLU(),
nn.Linear(hidden_features, hidden_features), nn.ReLU(),
nn.Linear(hidden_features, in_features // 2),))
self.scale = nn.Parameter(torch.zeros(in_features))
# logistic distribution for prior
base = torch.distributions.uniform.Uniform(torch.tensor(0.0).to('cuda'), torch.tensor(1.0).to('cuda'))
transforms = [torch.distributions.transforms.SigmoidTransform().inv,
torch.distributions.transforms.AffineTransform(torch.tensor(0.0).to('cuda'),
torch.tensor(1.0).to('cuda'))]
self.prior = torch.distributions.TransformedDistribution(base, transforms)
def forward(self, x):
# x: (batch_size, 784)
x = x.view(x.size(0), -1)
z, log_det = self.forward_(x)
log_prob = self.prior.log_prob(z).sum(dim=1) + log_det
return log_prob
def forward_(self, x):
# x: (batch_size, 784)
z = x
log_det = 0
s = torch.exp(self.scale) # Positive scale
for i in range(self.num_coupling):
# Check the number of layers in self.layers
# print(len(self.layers)) # 4
z1 = z.chunk(2, dim=1)[0] if (i % 2 == 0) else z.chunk(2, dim=1)[1]
z2 = z.chunk(2, dim=1)[1] if (i % 2 == 0) else z.chunk(2, dim=1)[0]
z2 = z2 + self.layers[i](z1)
z = torch.cat([z1, z2], dim=1) if (i % 2 == 0) else torch.cat([z2, z1], dim=1)
z = z * s
log_det += torch.log(s).sum()
return z, log_det
def inverse(self, z):
# z: (batch_size, 784)
z = z.view(z.size(0), -1)
x = self.inverse_(z)
x = x.view(x.size(0), 1, 28, 28)
return x
def inverse_(self, z):
# z: (batch_size, 784)
x = z
s = torch.exp(self.scale).to('cuda')
x = x / s
# print(x.is_cuda) # True
for i in range(self.num_coupling -1, -1, -1):
# Assumed that the num_coupling is even
x1 = x.chunk(2, dim=1)[0] if (i % 2 == 1) else x.chunk(2, dim=1)[1]
x2 = x.chunk(2, dim=1)[1] if (i % 2 == 1) else x.chunk(2, dim=1)[0]
x1 = x1 - self.layers[i](x2)
x = torch.cat([x1, x2], dim=1) if (i % 2 == 1) else torch.cat([x2, x1], dim=1)
return x
def sample(self, img_shape):
z = self.prior.sample(img_shape).to('cuda')
x = self.inverse(z)
return x
========================
It works for training, but error occurs when sampling:
import matplotlib.pyplot as plt
samples = model.sample(img_shape=[16,1,28,28])
samples = samples.view(16, 28, 28).cpu().numpy()
plt.figure(figsize=(10,10))
for i in range(16):
plt.subplot(4,4,i+1)
plt.imshow(samples[i], cmap=“gray”)
plt.axis(“off”)
plt.show()