Thanks for your reply. The problem is that I do not know how to feed the parameters of the trained CNN model into the VAE's encoder as its input.

Here is the CNN model:

```
class Net(nn.Module):
    """Two-convolution, 10-class image classifier.

    The hard-coded ``fc1`` input size of 320 = 20 channels * 4 * 4 implies
    batched, single-channel 28x28 inputs of shape ``(batch, 1, 28, 28)``:
    28 -> conv(5) -> 24 -> pool(2) -> 12 -> conv(5) -> 8 -> pool(2) -> 4.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return per-class log-probabilities of shape ``(batch, 10)``.

        The output is ``log_softmax`` over the class axis, so it pairs with
        ``F.nll_loss``; ``out.exp()`` gives probabilities summing to 1.
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Flatten per sample. Unlike view(-1, 320), a wrongly sized input now
        # fails in fc1 instead of being silently regrouped across the batch.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        # Functional dropout: active only while the module is in training mode.
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # Explicit class axis; the implicit-dim form is deprecated and warns.
        return F.log_softmax(x, dim=1)
```

And here is the VAE model:

```
class VAE(nn.Module):
    """Convolutional variational autoencoder.

    With the default ``h_dim=1024``, the encoder output matches ``fc1``/``fc2``
    only for 64x64 inputs. Four kernel-4, stride-2 convolutions shrink
    64 -> 31 -> 14 -> 6 -> 2, giving 256 * 2 * 2 = 1024 features. Assuming
    ``UnFlatten`` yields a 1x1 map, the decoder expands it back to 64x64
    (1 -> 5 -> 13 -> 30 -> 64).

    NOTE(review): ``Flatten`` and ``UnFlatten`` are defined elsewhere. This
    assumes ``Flatten`` produces ``(batch, -1)`` and ``UnFlatten`` reshapes to
    ``(batch, h_dim, 1, 1)``; confirm against their definitions.
    """

    def __init__(self, image_channels=1, h_dim=1024, z_dim=32):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(image_channels, 32, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2),
            nn.ReLU(),
            Flatten(),
        )
        self.fc1 = nn.Linear(h_dim, z_dim)  # encoder features -> mu
        self.fc2 = nn.Linear(h_dim, z_dim)  # encoder features -> logvar
        self.fc3 = nn.Linear(z_dim, h_dim)  # latent z -> decoder input
        self.decoder = nn.Sequential(
            UnFlatten(),
            nn.ConvTranspose2d(h_dim, 128, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(32, image_channels, kernel_size=6, stride=2),
            nn.Sigmoid(),
        )

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + exp(0.5 * logvar) * eps`` with ``eps ~ N(0, I)``.

        Sampling is expressed this way so gradients flow through ``mu`` and
        ``logvar``. The result has the same shape, dtype and device as ``mu``
        and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        # randn_like matches std's shape, dtype and device. The previous
        # torch.randn(*size) always allocated a default-dtype CPU tensor,
        # which fails with a device mismatch once the model is on a GPU.
        eps = torch.randn_like(std)
        return mu + std * eps

    def bottleneck(self, h):
        """Map encoder features ``h`` to ``(z, mu, logvar)``."""
        mu, logvar = self.fc1(h), self.fc2(h)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def encode(self, x):
        """Encode images ``x`` and return ``(z, mu, logvar)``."""
        h = self.encoder(x)
        return self.bottleneck(h)

    def decode(self, z):
        """Decode latent ``z`` into images with values in (0, 1) (final Sigmoid)."""
        z = self.fc3(z)
        return self.decoder(z)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for a batch of images ``x``."""
        z, mu, logvar = self.encode(x)
        return self.decode(z), mu, logvar
```