My images are of size 600x800. I found a VAE code online. I would like to try it on my own images (800 total images 160 of which are val images).
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import os
from skimage import io, transform
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
batch_size = 8
epochs = 50
no_cuda = False
seed = 1
log_interval = 50
cuda = not no_cuda and torch.cuda.is_available()
torch.manual_seed(seed)
device = torch.device("cuda" if cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}
print('device is {} and kwargs is {}'.format(device, kwargs))
train_root = 'labeled-data/train_moth'
val_root = 'labeled-data/val_moth'
train_loader_food = torch.utils.data.DataLoader(
datasets.ImageFolder(train_root, transform=transforms.ToTensor()),
batch_size = batch_size, shuffle=True, **kwargs)
val_loader_food = torch.utils.data.DataLoader(
datasets.ImageFolder(val_root, transform=transforms.ToTensor()),
batch_size = batch_size, shuffle=True, **kwargs)
class VAE_CNN(nn.Module):
def __init__(self):
super(VAE_CNN, self).__init__()
# Encoder
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(16)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(32)
self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(64)
self.conv4 = nn.Conv2d(64, 16, kernel_size=3, stride=2, padding=1, bias=False)
self.bn4 = nn.BatchNorm2d(16)
# Latent vectors mu and sigma
self.fc1 = nn.Linear(25 * 25 * 16, 2048)
self.fc_bn1 = nn.BatchNorm1d(2048)
self.fc21 = nn.Linear(2048, 2048)
self.fc22 = nn.Linear(2048, 2048)
# Sampling vector
self.fc3 = nn.Linear(2048, 2048)
self.fc_bn3 = nn.BatchNorm1d(2048)
self.fc4 = nn.Linear(2048, 25 * 25 * 16)
self.fc_bn4 = nn.BatchNorm1d(25 * 25 * 16)
# Decoder
self.conv5 = nn.ConvTranspose2d(16, 64, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False)
self.bn5 = nn.BatchNorm2d(64)
self.conv6 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=1, padding=1, bias=False)
self.bn6 = nn.BatchNorm2d(32)
self.conv7 = nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False)
self.bn7 = nn.BatchNorm2d(16)
self.conv8 = nn.ConvTranspose2d(16, 3, kernel_size=3, stride=1, padding=1, bias=False)
self.relu = nn.ReLU()
def encode(self, x):
conv1 = self.relu(self.bn1(self.conv1(x)))
conv2 = self.relu(self.bn2(self.conv2(conv1)))
conv3 = self.relu(self.bn3(self.conv3(conv2)))
conv4 = self.relu(self.bn4(self.conv4(conv3))).view(-1, 25 * 25 * 16)
fc1 = self.relu(self.fc_bn1(self.fc1(conv4)))
r1 = self.fc21(fc1)
r2 = self.fc22(fc1)
return r1, r2
def reparameterize(self, mu, logvar):
if self.training:
std = logvar.mul(0.5).exp_()
eps = Variable(std.data.new(std.size()).normal_())
return eps.mul(std).add_(mu)
else:
return mu
def decode(self, z):
fc3 = self.relu(self.fc_bn3(self.fc3(z)))
fc4 = self.relu(self.fc_bn4(self.fc4(fc3))).view(-1, 16, 25, 25)
conv5 = self.relu(self.bn5(self.conv5(fc4)))
conv6 = self.relu(self.bn6(self.conv6(conv5)))
conv7 = self.relu(self.bn7(self.conv7(conv6)))
return self.conv8(conv7).view(-1, 3, 100, 100)
def forward(self, x):
mu, logvar = self.encode(x)
z = self.reparameterize(mu, logvar)
return self.decode(z), mu, logvar
class customLoss(nn.Module):
def __init__(self):
super(customLoss, self).__init__()
self.mse_loss = nn.MSELoss(reduction="sum")
def forward(self, x_recon, x, mu, logvar):
loss_MSE = self.mse_loss(x_recon, x)
loss_KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return loss_MSE + loss_KLD
model = VAE_CNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_mse = customLoss()
val_losses = []
train_losses = []
def train(epoch):
model.train()
train_loss = 0
for batch_idx, (data, _) in enumerate(train_loader_food):
data = data.to(device)
optimizer.zero_grad()
recon_batch, mu, logvar = model(data)
loss = loss_mse(recon_batch, data, mu, logvar)
loss.backward()
train_loss += loss.item()
optimizer.step()
if batch_idx % log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader_food.dataset),
100. * batch_idx / len(train_loader_food),
loss.item() / len(data)))
print('====> Epoch: {} Average loss: {:.4f}'.format(
epoch, train_loss / len(train_loader_food.dataset)))
train_losses.append(train_loss / len(train_loader_food.dataset))
def test(epoch):
model.eval()
test_loss = 0
with torch.no_grad():
for i, (data, _) in enumerate(val_loader_food):
data = data.to(device)
recon_batch, mu, logvar = model(data)
test_loss += loss_mse(recon_batch, data, mu, logvar).item()
if i == 0:
n = min(data.size(0), 8)
comparison = torch.cat([data[:n],
recon_batch.view(batch_size, 3, 100, 100)[:n]])
save_image(comparison.cpu(),
'VAE_results/reconstruction_' + str(epoch) + '.png', nrow=n)
test_loss /= len(val_loader_food.dataset)
print('====> Test set loss: {:.4f}'.format(test_loss))
val_losses.append(test_loss)
for epoch in range(1, epochs + 1):
train(epoch)
test(epoch)
with torch.no_grad():
sample = torch.randn(2, 2048).to(device)
sample = model.decode(sample).cpu()
save_image(sample.view(2, 3, 100, 100),
'VAE_results/sample_' + str(epoch) + '.png')
The error is:
/home/mona/anaconda3/lib/python3.7/site-packages/torch/nn/modules/loss.py:445: UserWarning: Using a target size (torch.Size([8, 3, 600, 800])) that is different to the input size (torch.Size([384, 3, 100, 100])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
return F.mse_loss(input, target, reduction=self.reduction)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-15-37f467c4f834> in <module>
1 for epoch in range(1, epochs + 1):
----> 2 train(epoch)
3 test(epoch)
4 with torch.no_grad():
5 sample = torch.randn(2, 2048).to(device)
<ipython-input-13-8f191bde6513> in train(epoch)
6 optimizer.zero_grad()
7 recon_batch, mu, logvar = model(data)
----> 8 loss = loss_mse(recon_batch, data, mu, logvar)
9 loss.backward()
10 train_loss += loss.item()
~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
--> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),
<ipython-input-9-6c49edf3f96a> in forward(self, x_recon, x, mu, logvar)
5
6 def forward(self, x_recon, x, mu, logvar):
----> 7 loss_MSE = self.mse_loss(x_recon, x)
8 loss_KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
9
~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
--> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),
~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
443
444 def forward(self, input: Tensor, target: Tensor) -> Tensor:
--> 445 return F.mse_loss(input, target, reduction=self.reduction)
446
447
~/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py in mse_loss(input, target, size_average, reduce, reduction)
2645 ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
2646 else:
-> 2647 expanded_input, expanded_target = torch.broadcast_tensors(input, target)
2648 ret = torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
2649 return ret
~/anaconda3/lib/python3.7/site-packages/torch/functional.py in broadcast_tensors(*tensors)
63 if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):
64 return handle_torch_function(broadcast_tensors, tensors, *tensors)
---> 65 return _VF.broadcast_tensors(tensors)
66
67
RuntimeError: The size of tensor a (100) must match the size of tensor b (800) at non-singleton dimension 3
Here’s the link to tutorial: