# Image segmentation with cross-entropy loss

I am a new user of PyTorch.
I’d like to use the cross-entropy loss function.

number of classes=2
output.shape=[4, 2, 224, 224]
output_min=tensor(-1.9295)
output_max=tensor(2.6400)

number of channels=3
target.shape=[4, 3, 224, 224]
targets_max=tensor(-2.1008)
targets_min=tensor(-2.1179)

How should I evaluate:
`loss = criterion(output, target)`?
Thanks.

Hello Neo!

As an aside, for a two-class classification problem, you will be
better off treating this explicitly as a binary problem, rather than
as a two-class instance of the more general multi-class problem.
To do so you would use `BCEWithLogitsLoss` (“Binary Cross
Entropy”), rather than the multi-class `CrossEntropyLoss`.
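
As a minimal sketch of that binary formulation (the shapes and random tensors here are illustrative stand-ins, not your actual model): the model emits a single logit channel per pixel, and the target is a float `0.0` / `1.0` mask of the same shape:

```python
import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()

# one logit channel per pixel (raw scores, not probabilities)
output = torch.randn(4, 1, 224, 224)
# float mask of 0.0 / 1.0 with the SAME shape as the output
target = torch.randint(0, 2, (4, 1, 224, 224)).float()

loss = criterion(output, target)  # scalar loss
```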

But you can certainly treat this as a general multi-class problem.

So your outputs are raw-score logits, rather than probabilities
that lie between `0.0` and `1.0`. Good.

I assume that your first dimension, `4`, is your batch size. This is
fine.

It probably does not make sense to have a channels dimension in
your target. (If you think it does, you should further explain your
use case.) In any event, as it stands, this target shape won’t match
what `CrossEntropyLoss` expects.

If your output shape is `[nBatch, nClass, height, width]`,
then (for `CrossEntropyLoss`) your target shape must be
`[nBatch, height, width]`, with no `nClass` dimension.

This is wrong (for `CrossEntropyLoss`). Your target values must
be integer (`long`) class labels that run from `0` to `nClass - 1`,
so in your two-class case, they take on the values `0` and `1`.
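
Putting the two requirements together, here is a minimal sketch using your shapes (random tensors standing in for your model output and your mask):

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

# output: [nBatch, nClass, height, width], raw-score logits
output = torch.randn(4, 2, 224, 224)
# target: [nBatch, height, width] -- no nClass dimension --
# with integer (long) class labels in {0, 1}
target = torch.randint(0, 2, (4, 224, 224))

loss = criterion(output, target)  # scalar loss
```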

If you could explain a little more where `target` comes from, and
what the numbers actually mean, we can help sort this out.

Best.

K. Frank

I’m sorry for the delay; I had some problems.
And I’m sorry for my English.

my model is:

number of classes=2 or 3 or 10

And the output dimension of the model is [No x Co x Ho x Wo]
where,

```
No -> the batch size (same as Ni)
Co -> the number of classes that the dataset has
Ho -> the height of the image (which is the same as Hi in almost all cases)
Wo -> the width of the image (which is the same as Wi in almost all cases)
```

number of channels=1 or 3

the target dimension is [Ni x Ci x Hi x Wi]
where,

```
Ni -> the batch size
Ci -> the number of channels (which is 3 or 1)
Hi -> the height of the image
Wi -> the width of the image
```

I apply the training transforms to both the image and the mask:

```python
train_transform = et.ExtCompose([
    # et.ExtResize(size=opts.crop_size),
    et.ExtRandomScale((0.5, 2.0)),
    et.ExtRandomHorizontalFlip(),
    et.ExtToTensor(),
    et.ExtNormalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]),
])
```

The mask is RGB or grayscale, with value 0 for the background, 1 for class 1, 2 for class 2, and so on.

Hello Neo!

`Ho` and `Wo` must be the same as `Hi` and `Wi` in all cases, not
just in “almost all” cases.

This might be what you have, but it simply won’t work.
`CrossEntropyLoss` requires that, for a model output of shape
`[No, Co, Ho, Wo]`, the target have shape `[No, Ho, Wo]`
(and that the values of the target are integer class labels that
run from `0` to `Co - 1`).

If your target has this extra “channel” dimension (`Ci`), it won’t
work (and `Hi` and `Wi` must match `Ho` and `Wo`, as well). (Just
to be clear, you can’t have the `Ci` dimension at all, even if
`Ci = 1`.)
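
As a minimal sketch of the fix, assuming the mask stores the class index in a single channel (the shapes and random tensors are illustrative): drop the `Ci` dimension and make sure the dtype is `long`:

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

# model output: [No, Co, Ho, Wo] logits for 3 classes
output = torch.randn(4, 3, 224, 224)

# target as loaded: [Ni, 1, Hi, Wi] with integer class values
target = torch.randint(0, 3, (4, 1, 224, 224))

# drop the channel dimension and cast to long: -> [4, 224, 224]
target = target.squeeze(1).long()

loss = criterion(output, target)  # scalar loss
```

If the mask is instead RGB with the class index replicated across its three channels, selecting one channel with `target[:, 0]` in place of `squeeze(1)` would serve the same purpose.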

Good luck.

K. Frank

Hey, sorry to disturb you; I just wanted to confirm:
1. The raw logits are supposed to be one-hot encoded along the class dimension, say with a sample shape of `(1, 6, 256, 256)` for multi-class classification with `6` labels.
I am confused why PyTorch doesn’t do this implicitly, though. Was my understanding correct?