VAE Training for Custom Images

My images are 600x800 (height x width). I found a VAE implementation online and would like to try it on my own images (800 images in total, 160 of which are validation images).

import matplotlib.pyplot as plt  # kept for plotting the loss curves later
import os
import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torchvision.utils import save_image

batch_size = 8
epochs = 50
no_cuda = False
seed = 1
log_interval = 50

cuda = not no_cuda and torch.cuda.is_available()

torch.manual_seed(seed)
device = torch.device("cuda" if cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}
print('device is {} and kwargs is {}'.format(device, kwargs))
train_root = 'labeled-data/train_moth'
val_root = 'labeled-data/val_moth'
train_loader_food = torch.utils.data.DataLoader(
    datasets.ImageFolder(train_root, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)

val_loader_food = torch.utils.data.DataLoader(
    datasets.ImageFolder(val_root, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)
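Just to confirm the data side, here is what the loader actually produces (ImageFolder keeps the native resolution; this matches the target size in the warning further down):

images, _ = next(iter(train_loader_food))
print(images.shape)  # torch.Size([8, 3, 600, 800]) on my data -- but the network below expects 100x100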
class VAE_CNN(nn.Module):
    def __init__(self):
        super(VAE_CNN, self).__init__()

        # Encoder
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn4 = nn.BatchNorm2d(16)

        # Latent vectors mu and sigma
        self.fc1 = nn.Linear(25 * 25 * 16, 2048)
        self.fc_bn1 = nn.BatchNorm1d(2048)
        self.fc21 = nn.Linear(2048, 2048)
        self.fc22 = nn.Linear(2048, 2048)

        # Sampling vector
        self.fc3 = nn.Linear(2048, 2048)
        self.fc_bn3 = nn.BatchNorm1d(2048)
        self.fc4 = nn.Linear(2048, 25 * 25 * 16)
        self.fc_bn4 = nn.BatchNorm1d(25 * 25 * 16)

        # Decoder
        self.conv5 = nn.ConvTranspose2d(16, 64, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False)
        self.bn5 = nn.BatchNorm2d(64)
        self.conv6 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn6 = nn.BatchNorm2d(32)
        self.conv7 = nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False)
        self.bn7 = nn.BatchNorm2d(16)
        self.conv8 = nn.ConvTranspose2d(16, 3, kernel_size=3, stride=1, padding=1, bias=False)

        self.relu = nn.ReLU()

    def encode(self, x):
        conv1 = self.relu(self.bn1(self.conv1(x)))
        conv2 = self.relu(self.bn2(self.conv2(conv1)))
        conv3 = self.relu(self.bn3(self.conv3(conv2)))
        conv4 = self.relu(self.bn4(self.conv4(conv3))).view(-1, 25 * 25 * 16)  # hard-codes a 100x100 input (two stride-2 convs -> 25x25)

        fc1 = self.relu(self.fc_bn1(self.fc1(conv4)))

        r1 = self.fc21(fc1)
        r2 = self.fc22(fc1)
        
        return r1, r2

    def reparameterize(self, mu, logvar):
        if self.training:
            std = torch.exp(0.5 * logvar)  # logvar is log(sigma^2)
            eps = torch.randn_like(std)    # noise with the same shape/device as std
            return mu + eps * std
        else:
            return mu

    def decode(self, z):
        fc3 = self.relu(self.fc_bn3(self.fc3(z)))
        fc4 = self.relu(self.fc_bn4(self.fc4(fc3))).view(-1, 16, 25, 25)

        conv5 = self.relu(self.bn5(self.conv5(fc4)))
        conv6 = self.relu(self.bn6(self.conv6(conv5)))
        conv7 = self.relu(self.bn7(self.conv7(conv6)))
        return self.conv8(conv7).view(-1, 3, 100, 100)  # hard-codes the 100x100 output size

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
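As far as I can tell, the architecture is hard-wired to 100x100 inputs: two stride-2 convs reduce 100x100 to 25x25, which is exactly what fc1 (25 * 25 * 16) and the final view expect. A quick dummy-input check I ran to confirm my understanding:

check = VAE_CNN().eval()  # eval() so reparameterize returns mu and BatchNorm uses running stats
with torch.no_grad():
    recon, mu, logvar = check(torch.randn(2, 3, 100, 100))
print(recon.shape, mu.shape)  # torch.Size([2, 3, 100, 100]) torch.Size([2, 2048])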
class customLoss(nn.Module):
    def __init__(self):
        super(customLoss, self).__init__()
        self.mse_loss = nn.MSELoss(reduction="sum")

    def forward(self, x_recon, x, mu, logvar):
        loss_MSE = self.mse_loss(x_recon, x)
        loss_KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        return loss_MSE + loss_KLD
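For reference, this is the standard VAE objective: a summed reconstruction error plus the closed-form KL divergence -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) between the approximate posterior N(mu, sigma^2) and the unit Gaussian prior. One thing I noticed while debugging: nn.MSELoss does not insist that input and target have the same shape -- it broadcasts them (with the UserWarning shown below) and only raises a hard error when broadcasting is impossible.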
model = VAE_CNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_mse = customLoss()
val_losses = []
train_losses = []

def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader_food):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_mse(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader_food.dataset),
                       100. * batch_idx / len(train_loader_food),
                       loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader_food.dataset)))
    train_losses.append(train_loss / len(train_loader_food.dataset))

def test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(val_loader_food):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_mse(recon_batch, data, mu, logvar).item()
            if i == 0:
                n = min(data.size(0), 8)
                comparison = torch.cat([data[:n],
                                        recon_batch[:n]])  # decode() already returns (-1, 3, 100, 100)
                save_image(comparison.cpu(),
                           'VAE_results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(val_loader_food.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))
    val_losses.append(test_loss)

os.makedirs('VAE_results', exist_ok=True)  # save_image fails if the output folder doesn't exist

for epoch in range(1, epochs + 1):
    train(epoch)
    test(epoch)
    with torch.no_grad():
        sample = torch.randn(2, 2048).to(device)
        sample = model.decode(sample).cpu()
        save_image(sample.view(2, 3, 100, 100),
                   'VAE_results/sample_' + str(epoch) + '.png')

The error is:

/home/mona/anaconda3/lib/python3.7/site-packages/torch/nn/modules/loss.py:445: UserWarning: Using a target size (torch.Size([8, 3, 600, 800])) that is different to the input size (torch.Size([384, 3, 100, 100])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  return F.mse_loss(input, target, reduction=self.reduction)

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-15-37f467c4f834> in <module>
      1 for epoch in range(1, epochs + 1):
----> 2     train(epoch)
      3     test(epoch)
      4     with torch.no_grad():
      5         sample = torch.randn(2, 2048).to(device)

<ipython-input-13-8f191bde6513> in train(epoch)
      6         optimizer.zero_grad()
      7         recon_batch, mu, logvar = model(data)
----> 8         loss = loss_mse(recon_batch, data, mu, logvar)
      9         loss.backward()
     10         train_loss += loss.item()

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

<ipython-input-9-6c49edf3f96a> in forward(self, x_recon, x, mu, logvar)
      5 
      6     def forward(self, x_recon, x, mu, logvar):
----> 7         loss_MSE = self.mse_loss(x_recon, x)
      8         loss_KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
      9 

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
    443 
    444     def forward(self, input: Tensor, target: Tensor) -> Tensor:
--> 445         return F.mse_loss(input, target, reduction=self.reduction)
    446 
    447 

~/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py in mse_loss(input, target, size_average, reduce, reduction)
   2645             ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
   2646     else:
-> 2647         expanded_input, expanded_target = torch.broadcast_tensors(input, target)
   2648         ret = torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
   2649     return ret

~/anaconda3/lib/python3.7/site-packages/torch/functional.py in broadcast_tensors(*tensors)
     63         if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):
     64             return handle_torch_function(broadcast_tensors, tensors, *tensors)
---> 65     return _VF.broadcast_tensors(tensors)
     66 
     67 

RuntimeError: The size of tensor a (100) must match the size of tensor b (800) at non-singleton dimension 3
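
If I trace the shapes, I think I can see where the 384 comes from: the two stride-2 convs downsample a 600x800 image to 150x200, so conv4's output for a batch of 8 is [8, 16, 150, 200], i.e. 8 * 16 * 150 * 200 = 3,840,000 values. The hard-coded view(-1, 25 * 25 * 16) silently reshapes that into 3,840,000 / 10,000 = 384 rows, the decoder turns those into 384 supposed 100x100 images, and MSELoss then fails to broadcast them against the 8 real 600x800 targets.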

Here’s the link to the tutorial:

Would it make sense to use something like the following, or would it be better to change the network so that it works with my 600x800 images? (A rough sketch of what I think the latter would involve is at the end.)
I'm mostly confused because my images are not square (x by x) but rectangular (x by y), and I am forcing them to 100x100 to fit the network.

my_transform = transforms.Compose([
    transforms.Resize((100, 100)),  # put Resize before ToTensor (older torchvision Resize only accepts PIL images)
    transforms.ToTensor(),
])

train_loader_food = torch.utils.data.DataLoader(
    datasets.ImageFolder(train_root, transform=my_transform),
    batch_size=batch_size, shuffle=True, **kwargs)

val_loader_food = torch.utils.data.DataLoader(
    datasets.ImageFolder(val_root, transform=my_transform),
    batch_size=batch_size, shuffle=True, **kwargs)
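
And in case it clarifies the second option I mean: below is my rough, untested sketch of the layers I believe would have to change to keep the native 600x800 resolution (600x800 becomes 150x200 after the two stride-2 convs). Note that fc1 would then take 150 * 200 * 16 = 480,000 inputs, i.e. close to a billion weights, which is partly why I suspect resizing is the saner route.

# hypothetical edits to VAE_CNN for native 600x800 inputs -- untested sketch
self.fc1 = nn.Linear(150 * 200 * 16, 2048)      # encoder flatten size
self.fc4 = nn.Linear(2048, 150 * 200 * 16)
self.fc_bn4 = nn.BatchNorm1d(150 * 200 * 16)

# encode(): conv4 = ...view(-1, 150 * 200 * 16)
# decode(): fc4 = ...view(-1, 16, 150, 200)
#           return self.conv8(conv7).view(-1, 3, 600, 800)
# and the sampling at the bottom would use sample.view(2, 3, 600, 800)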