Hello there,
I am trying to build a ConditionalVAE myself and use it to reconstruct images from CIFAR-100.
I followed the code in PyTorch-VAE/cvae.py at master · AntixK/PyTorch-VAE · GitHub and changed the parameters so that it works at a larger resolution (224 × 224). The __init__ of my CVAE is listed below, with hidden_dims = [32, 64, 128, 256]; the model is constructed as ConditionalVAE(3, 100, 256, [32, 64, 128, 256], 224).
def __init__(self,
             in_channels: int,
             num_classes: int,
             latent_dim: int,
             hidden_dims: List = None,
             img_size: int = 64,
             **kwargs) -> None:
    super(ConditionalVAE, self).__init__()

    self._kernel_size = 5
    self._stride = 3
    self._padding = 1
    self.latent_dim = latent_dim
    self.img_size = img_size
    self.num_classes = num_classes

    self.embed_class = nn.Linear(num_classes, img_size * img_size)
    self.embed_data = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    modules = []
    if hidden_dims is None:
        hidden_dims = [32, 64, 128, 256]
    self.hidden_dims = hidden_dims.copy()

    self.n = img_size
    output_padding = []
    for i in range(len(hidden_dims)):
        tmp = (self.n + 2 * self._padding - self._kernel_size + self._stride)
        self.n = tmp // self._stride
        output_padding.append(tmp - self._stride * self.n)
    output_padding.reverse()

    in_channels += 1  # To account for the extra label channel

    # Build Encoder
    prev_dim = in_channels
    for h_dim in hidden_dims:
        modules.append(
            nn.Sequential(
                nn.Conv2d(prev_dim, out_channels=h_dim,
                          kernel_size=self._kernel_size, stride=self._stride, padding=self._padding),
                nn.BatchNorm2d(h_dim),
                nn.LeakyReLU())
        )
        prev_dim = h_dim

    self.encoder = nn.Sequential(*modules)
    self.fc_mu = nn.Linear(hidden_dims[-1] * self.n * self.n, latent_dim)
    self.fc_var = nn.Linear(hidden_dims[-1] * self.n * self.n, latent_dim)

    # Build Decoder
    modules = []
    self.decoder_input = nn.Linear(latent_dim + num_classes, hidden_dims[-1] * self.n * self.n)
    hidden_dims.reverse()

    for i in range(len(hidden_dims) - 1):
        modules.append(
            nn.Sequential(
                nn.ConvTranspose2d(hidden_dims[i],
                                   hidden_dims[i + 1],
                                   kernel_size=self._kernel_size,
                                   stride=self._stride,
                                   padding=self._padding,
                                   output_padding=output_padding[i]),
                nn.BatchNorm2d(hidden_dims[i + 1]),
                nn.LeakyReLU())
        )

    self.decoder = nn.Sequential(*modules)

    self.final_layer = nn.Sequential(
        nn.ConvTranspose2d(hidden_dims[-1],
                           hidden_dims[-1],
                           kernel_size=self._kernel_size,
                           stride=self._stride,
                           padding=self._padding,
                           output_padding=output_padding[-1]),
        nn.BatchNorm2d(hidden_dims[-1]),
        nn.LeakyReLU(),
        nn.Conv2d(hidden_dims[-1], out_channels=in_channels - 1,
                  kernel_size=self._kernel_size, padding=(self._kernel_size - 1) // 2),
        nn.Tanh()
    )
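For reference, with img_size = 224, kernel_size = 5, stride = 3 and padding = 1, the size-tracking loop above should give the spatial sizes and output_padding values below (a quick standalone sanity check using the same formula, not part of the model itself):

    # Sanity check of the spatial sizes and output_padding values computed in __init__
    # (same formula as the loop above; assumes img_size = 224, kernel 5, stride 3, padding 1).
    n, k, s, p = 224, 5, 3, 1
    sizes, output_padding = [n], []
    for _ in range(4):                       # one step per entry of hidden_dims
        tmp = n + 2 * p - k + s
        n = tmp // s
        sizes.append(n)
        output_padding.append(tmp - s * n)   # rows/cols the matching ConvTranspose2d must add back
    output_padding.reverse()

    print(sizes)             # [224, 74, 24, 8, 2]  -> encoder output is 256 x 2 x 2
    print(output_padding)    # [2, 0, 2, 2]         -> first three go to the decoder blocks, the last to final_layer

so the decoder should upsample 2 -> 8 -> 24 -> 74 and final_layer should bring it back to 224.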
When I tried to reconstruct images with this model, I found that the reconstructed images have gray blocks along their right and bottom edges (for simplicity, I retrieved a very early version of the CVAE, so the overall quality is bad), like this:
How could I get rid of these blocks? Thanks in advance.
The whole class of my CVAE is listed below:
class ConditionalVAE(BaseVAE):

    def __init__(self,
                 in_channels: int,
                 num_classes: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 img_size: int = 64,
                 **kwargs) -> None:
        super(ConditionalVAE, self).__init__()

        self._kernel_size = 5
        self._stride = 3
        self._padding = 1
        self.latent_dim = latent_dim
        self.img_size = img_size
        self.num_classes = num_classes

        self.embed_class = nn.Linear(num_classes, img_size * img_size)
        self.embed_data = nn.Conv2d(in_channels, in_channels, kernel_size=1)

        modules = []
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256]
        self.hidden_dims = hidden_dims.copy()

        self.n = img_size
        output_padding = []
        for i in range(len(hidden_dims)):
            tmp = (self.n + 2 * self._padding - self._kernel_size + self._stride)
            self.n = tmp // self._stride
            output_padding.append(tmp - self._stride * self.n)
        output_padding.reverse()

        in_channels += 1  # To account for the extra label channel

        # Build Encoder
        prev_dim = in_channels
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(prev_dim, out_channels=h_dim,
                              kernel_size=self._kernel_size, stride=self._stride, padding=self._padding),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            prev_dim = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1] * self.n * self.n, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * self.n * self.n, latent_dim)

        # Build Decoder
        modules = []
        self.decoder_input = nn.Linear(latent_dim + num_classes, hidden_dims[-1] * self.n * self.n)
        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=self._kernel_size,
                                       stride=self._stride,
                                       padding=self._padding,
                                       output_padding=output_padding[i]),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1],
                               hidden_dims[-1],
                               kernel_size=self._kernel_size,
                               stride=self._stride,
                               padding=self._padding,
                               output_padding=output_padding[-1]),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=in_channels - 1,
                      kernel_size=self._kernel_size, padding=(self._kernel_size - 1) // 2),
            nn.Tanh()
        )
    def encode(self, inputs: Tensor) -> List[Tensor]:
        """
        Encodes the inputs by passing through the encoder network
        and returns the latent codes.
        :param inputs: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) List of latent codes
        """
        result = self.encoder(inputs)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)
        return [mu, log_var]
    def decode(self, z: Tensor) -> Tensor:
        result = self.decoder_input(z)
        result = result.view(-1, self.hidden_dims[-1], self.n, self.n)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result
    def reparameterize(self, mu: Tensor, log_var: Tensor) -> Tensor:
        """
        Will a single z be enough to compute the expectation
        for the loss??
        :param mu: (Tensor) Mean of the latent Gaussian
        :param log_var: (Tensor) Log variance of the latent Gaussian
        :return:
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu
    def forward(self, inputs: Tensor, **kwargs) -> List[Tensor]:
        y = self.one_hot(kwargs['labels'])
        embedded_class = self.embed_class(y)
        embedded_class = embedded_class.view(-1, self.img_size, self.img_size).unsqueeze(1)
        embedded_input = self.embed_data(inputs)

        x = torch.cat([embedded_input, embedded_class], dim=1)
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        z = torch.cat([z, y], dim=1)
        return [self.decode(z), inputs, mu, log_var]
    def sample(self,
               num_samples: int,
               current_device: int,
               **kwargs) -> Tensor:
        """
        Samples from the latent space and returns the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        y = self.one_hot(kwargs['labels'])
        z = torch.randn(num_samples,
                        self.latent_dim)
        z = z.to(current_device)
        z = torch.cat([z, y], dim=1)
        samples = self.decode(z)
        return samples
    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """
        return self.forward(x, **kwargs)[0]
    def one_hot(self, labels: Tensor) -> Tensor:
        device = labels.device
        targets = torch.zeros(labels.size(0), self.num_classes).to(device)
        idx = torch.arange(labels.size(0)).to(device)
        targets[idx, labels] = 1
        return targets.float()
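For completeness, this is roughly how I exercise the model (a minimal sketch with a random batch standing in for resized CIFAR-100 images; it assumes BaseVAE and the usual imports from the PyTorch-VAE repo are available):

    import torch

    # Dummy data in place of a real CIFAR-100 batch resized to 224 x 224.
    model = ConditionalVAE(3, 100, 256, [32, 64, 128, 256], 224)
    images = torch.randn(4, 3, 224, 224)   # dummy image batch
    labels = torch.randint(0, 100, (4,))   # dummy class labels

    # forward returns [reconstruction, input, mu, log_var]
    recons, orig, mu, log_var = model(images, labels=labels)
    print(recons.shape)                    # expected: torch.Size([4, 3, 224, 224])
    print(mu.shape, log_var.shape)         # expected: torch.Size([4, 256]) each

(With these settings the output comes back at 224 × 224, so the problem seems to be only in the pixel values near the right/bottom border rather than in the output size.)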