Reducing and Reconstructing CNN model parameters using a VAE

Hi there,

Suppose I have a simple CNN model with 2 Conv2d layers that I have trained on my image dataset. I want to feed the parameters of this CNN into a VAE (as the input of the encoder) to first reduce them to an embedding (the latent space Z of the VAE), and then reconstruct the CNN parameters, with their original dimensions, from the output of the VAE's decoder.

I do not know how to implement this in PyTorch, or how I would map the reconstructed vector of parameters back onto the CNN model's parameters.

Thanks in advance!


You can try a tutorial first, e.g. Variational Autoencoder (VAE) — PyTorch Tutorial | by Reza Kalantar | Nov, 2022 | Medium. From your explanation above, your approach is already on the right track; just don't forget the reparameterization step before decoding.
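As a rough sketch (assuming the encoder outputs a mean mu and a log-variance logvar for the latent), the reparameterization step would look something like:

def reparameterize(mu, logvar):
    # z = mu + sigma * eps with eps ~ N(0, I), so the sampling stays differentiable
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + std * eps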

Thanks for your reply. The problem, though, is that I do not know how to feed the parameters of the trained CNN model into the encoder of the VAE as its input.

Here is the CNN model:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

and the VAE model:

class VAE(nn.Module):
    def __init__(self, image_channels=1, h_dim=1024, z_dim=32):
        super(VAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(image_channels, 32, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Flatten()  # flatten conv features to [batch_size, h_dim]
        )

        self.fc1 = nn.Linear(h_dim, z_dim)  # mu head
        self.fc2 = nn.Linear(h_dim, z_dim)  # logvar head
        self.fc3 = nn.Linear(z_dim, h_dim)

        self.decoder = nn.Sequential(
            nn.Unflatten(1, (h_dim, 1, 1)),  # [batch_size, h_dim] -> [batch_size, h_dim, 1, 1]
            nn.ConvTranspose2d(h_dim, 128, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(32, image_channels, kernel_size=6, stride=2),
            nn.Sigmoid(),
        )

    def reparameterize(self, mu, logvar):
        std = logvar.mul(0.5).exp()
        eps = torch.randn_like(std)  # sampled on the same device/dtype as mu/std
        z = mu + std * eps
        return z

    def bottleneck(self, h):
        mu, logvar = self.fc1(h), self.fc2(h)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def encode(self, x):
        h = self.encoder(x)
        z, mu, logvar = self.bottleneck(h)
        return z, mu, logvar

    def decode(self, z):
        z = self.fc3(z)
        z = self.decoder(z)
        return z

    def forward(self, x):
        z, mu, logvar = self.encode(x)
        z = self.decode(z)
        return z, mu, logvar

@ptrblck, I would be happy to hear your comments. Thanks!

Unsure if I understand your use case correctly, but this minimal code snippet might work. Note that the VAE you posted expects image-shaped inputs, while the flattened parameters form a 1D vector, so the example below uses linear layers instead:

class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 3, 1, 1)
        self.conv2 = nn.Conv2d(6, 1, 3, 1, 1)
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.conv2(x)
        return x
    
class VAE(nn.Module):
    # simplified stand-in: a plain linear bottleneck without the mu/logvar heads,
    # just to demonstrate the flattening and reconstruction path
    def __init__(self, in_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, 1)
        self.fc2 = nn.Linear(1, in_features)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
    
cnn = CNN()
# flatten all parameters into one vector and treat it as a single input sample
flat_weights = torch.cat([p.view(-1) for p in cnn.parameters()]).unsqueeze(0)
in_features = flat_weights.size(1)
vae = VAE(in_features)

out = vae(flat_weights)
out.mean().backward()

# gradients flow back into the original CNN parameters
print([p.grad.abs().sum() for p in cnn.parameters()])
# [tensor(0.0165), tensor(0.0006), tensor(0.0055), tensor(0.0001)]

In short: extract the CNN parameters, flatten them into a single vector, encode that vector with the VAE's encoder into the latent space, then decode it to reconstruct a vector with the original number of elements, and finally write these reconstructed values back into the CNN's parameters.
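For the last step, here is a minimal sketch of writing the reconstructed vector back into the model, using PyTorch's torch.nn.utils helpers and assuming vae maps a flat [1, num_params] input to a reconstruction of the same size (like the linear VAE in the snippet above):

from torch.nn.utils import parameters_to_vector, vector_to_parameters

cnn = CNN()
flat = parameters_to_vector(cnn.parameters())       # 1-D vector of all parameters
reconstructed = vae(flat.unsqueeze(0)).squeeze(0)   # same number of elements back

with torch.no_grad():
    # copies the values back into conv1.weight, conv1.bias, ... in order
    vector_to_parameters(reconstructed, cnn.parameters())

vector_to_parameters overwrites the parameters in place, iterating in the same order as cnn.parameters(), so the reconstructed vector must contain exactly the same number of elements as the original.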