Facing issue with ACGAN: RuntimeError: 0D or 1D target tensor expected, multi-target not supported

I am trying to reimplement ACGAN code from ACGAN_Chromos/acgan128.py at master · jvirico/ACGAN_Chromos · GitHub for X-ray images.

I have modified the Custom dataset class according to my own need. Here is modified code:

import argparse
import os
import numpy as np
import math
import matplotlib.pyplot as plt
import cv2
import glob
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision.datasets.folder import pil_loader
from torchvision.datasets.utils import list_dir, list_files
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of        gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--n_classes", type=int, default=10, help="number of classes for dataset")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_args()

## Hyperparameter customization 
#  (overrides command line arguments, 
#  will be removed at the end)
#####################################
opt.n_epochs = 1
opt.batch_size = 64
# Adam Optimizer
opt.lr = 0.0002
opt.b1 = 0.5
opt.b2 = 0.999
#
opt.n_cpu = 8
#
opt.latent_dim = 100
opt.n_classes = 2
opt.img_size = 128
opt.channels = 3
opt.sample_interval = 400

# Results
save_ckp_every = 10 #epochs
results_folder = 'results/model'+str(opt.img_size)+'_ep'+ str(opt.n_epochs)+'_bs'+ str(opt.batch_size)
ckp_folder = results_folder + '/checkpoints'
#####################################
## Dataset manipulation
#####################################
transform0 = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
 ])

  ## Data Transformation to apply
  transform = transform0
  #####################################

   os.makedirs(ckp_folder, exist_ok=True)
   os.makedirs(results_folder + "/images", exist_ok=True)
   os.makedirs(results_folder + "/plots", exist_ok=True)
   print(opt)


class CustomDataset(Dataset):
    def __init__(self):
        self.imgs_path = "Un-norm/"
        file_list = glob.glob(self.imgs_path + "*")
        print(file_list)
        self.data = []
        for class_path in file_list:
            class_name = class_path.split("/")[-1]
            for img_path in glob.glob(class_path + "/*.jpg"):
                self.data.append([img_path, class_name])
        print(self.data)
        self.class_map = {"NORMAL" : 0, "PNEUMONIA": 1}
        self.img_dim = (128, 128)    
    def __len__(self):
        return len(self.data)    
    def __getitem__(self, idx):
        img_path, class_name = self.data[idx]
        img = cv2.imread(img_path)
        img = cv2.resize(img, self.img_dim)
        img = (img-127.5) / 127.5
        class_id = self.class_map[class_name]
        img_tensor = torch.from_numpy(img)
        img_tensor = img_tensor.permute(2, 0, 1)
        class_id = torch.tensor([class_id])
        return img_tensor, class_id


train_set = CustomDataset()
dataloader = DataLoader(train_set, batch_size=opt.batch_size, shuffle=True, drop_last=False, num_workers=2, pin_memory=True)

cuda = True if torch.cuda.is_available() else False


def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.label_emb = nn.Embedding(opt.n_classes, opt.latent_dim)

        self.init_size = opt.img_size // 4  # Initial size before upsampling
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            #nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        gen_input = torch.mul(self.label_emb(labels), noise)
        out = self.l1(gen_input)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            """Returns layers of each discriminator block"""
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.conv_blocks = nn.Sequential(
            *discriminator_block(opt.channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # The height and width of downsampled image
        ds_size = opt.img_size // 2 ** 4

    # Output layers
    self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())
    self.aux_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, opt.n_classes), nn.Softmax())

    def forward(self, img):
        out = self.conv_blocks(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        label = self.aux_layer(out)

        return validity, label


# Loss functions
adversarial_loss = torch.nn.BCELoss()
auxiliary_loss = torch.nn.CrossEntropyLoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
    auxiliary_loss.cuda()

# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)


# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor


def sample_image(n_row, batches_done):
    """Saves a grid of generated images ranging from 0 to n_classes"""
    # Sample noise
    z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))
    # Get labels ranging from 0 to n_classes for n rows
    labels = np.array([num for _ in range(n_row) for num in range(n_row)])
    labels = Variable(LongTensor(labels))
    gen_imgs = generator(z, labels)
    save_image(gen_imgs.data, results_folder + "/images/%d.png" % batches_done, nrow=n_row, normalize=True)


# ----------
#  Training
# ----------

losses = []
accuracies = []
iteration_checkpoints = []

for epoch in range(opt.n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):

        batch_size = imgs.shape[0]

        # Adversarial ground truths
        valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)
        fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)

    # Configure input
    real_imgs = Variable(imgs.type(FloatTensor))
    labels = Variable(labels.type(LongTensor))

    # -----------------
    #  Train Generator
    # -----------------

    optimizer_G.zero_grad()

    # Sample noise and labels as generator input
    z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))
    gen_labels = Variable(LongTensor(np.random.randint(0, opt.n_classes, batch_size)))

    # Generate a batch of images
    gen_imgs = generator(z, gen_labels)

    # Loss measures generator's ability to fool the discriminator
    validity, pred_label = discriminator(gen_imgs)
    g_loss = 0.5 * (adversarial_loss(validity, valid) + auxiliary_loss(pred_label, gen_labels))

    g_loss.backward()
    optimizer_G.step()
    
    # ---------------------
    #  Train Discriminator
    # ---------------------

    optimizer_D.zero_grad()

    # Loss for real images
    real_pred, real_aux = discriminator(real_imgs)
    d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2

    # Loss for fake images
    fake_pred, fake_aux = discriminator(gen_imgs.detach())
    d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, gen_labels)) / 2

    # Total discriminator loss
    d_loss = (d_real_loss + d_fake_loss) / 2

    # Calculate discriminator accuracy
    pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)
    gt = np.concatenate([labels.data.cpu().numpy(), gen_labels.data.cpu().numpy()], axis=0)
    d_acc = np.mean(np.argmax(pred, axis=1) == gt)

    d_loss.backward()
    optimizer_D.step()

    print(
        "[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"
        % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), 100 * d_acc, g_loss.item())
    )
    batches_done = epoch * len(dataloader) + i
    if batches_done % opt.sample_interval == 0:
        sample_image(n_row=10, batches_done=batches_done)
    
    # Saving losses and accuracy
    losses.append((d_loss, g_loss))
    accuracies.append(100.0 * d_acc)
    iteration_checkpoints.append(epoch + 1)

    # save models every 10 epochs
    if (epoch + 1 ) % opt.n_epochs == 0 or ((epoch+1) < opt.n_epochs and (epoch+1) % save_ckp_every == 0):
        torch.save(generator.state_dict(), ckp_folder+'/G_{0}.pt'.format(epoch+1))
        torch.save(discriminator.state_dict(), ckp_folder+'/D_{0}.pt'.format(epoch+1))


losses = np.array(losses)

# Plot training losses for Discriminator and Generator
plt.figure(figsize=(15, 5))
plt.plot(iteration_checkpoints, losses.T[0], label="Discriminator loss")
plt.plot(iteration_checkpoints, losses.T[1], label="Generator loss")

plt.xticks(iteration_checkpoints, rotation=90)

plt.title("Training Loss")
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.legend()
plt.savefig(results_folder + "/plots/Losses.png")

accuracies = np.array(accuracies)

# Plot Discriminator accuracy
plt.figure(figsize=(15, 5))
plt.plot(iteration_checkpoints, accuracies, label="Discriminator accuracy")

plt.xticks(iteration_checkpoints, rotation=90)
plt.yticks(range(0, 100, 5))

plt.title("Discriminator Accuracy")
plt.xlabel("Iteration")
plt.ylabel("Accuracy (%)")
plt.legend()
plt.savefig(results_folder + "/plots/Accuracy.png")
plt.show()

I am facing this error using Linux on GPU clusters.

The error log is:

Loading opencv3-py37-cuda10.1-gcc/3.4.11
  Loading requirement: openblas/dynamic/0.2.20 hdf5_18/1.8.20
    cuda10.1/toolkit/10.1.243 gcc5/5.5.0 python37
    ml-pythondeps-py37-cuda10.1-gcc/4.1.2
Loading tensorflow2-py37-cuda10.1-gcc/2.2.0
  Loading requirement: cudnn7.6-cuda10.1/7.6.5.32 keras-py37-cuda10.1-gcc/2.3.1
    protobuf3-gcc/3.8.0 nccl2-cuda10.1-gcc/2.7.8
Currently Loaded Modulefiles:
 1) gcc/8.2.0                   9) ml-pythondeps-py37-cuda10.1-gcc/4.1.2
 2) slurm/18.08.9              10) opencv3-py37-cuda10.1-gcc/3.4.11
 3) shared                     11) cudnn7.6-cuda10.1/7.6.5.32
 4) openblas/dynamic/0.2.20    12) keras-py37-cuda10.1-gcc/2.3.1
 5) hdf5_18/1.8.20             13) protobuf3-gcc/3.8.0
 6) cuda10.1/toolkit/10.1.243  14) nccl2-cuda10.1-gcc/2.7.8
 7) gcc5/5.5.0                 15) tensorflow2-py37-cuda10.1-gcc/2.2.0
 8) python37
/home/r00206978/.local/lib/python3.7/site-packages/torch/nn/modules/container.py:141: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.
      input = module(input)
    Traceback (most recent call last):
      File "acgan_xraypt.py", line 281, in <module>
        d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2
      File "/home/r00206978/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/r00206978/.local/lib/python3.7/site-packages/torch/nn/modules/loss.py", line 1165, in forward
        label_smoothing=self.label_smoothing)
      File "/home/r00206978/.local/lib/python3.7/site-packages/torch/nn/functional.py", line 2996, in cross_entropy
        return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
    RuntimeError: 0D or 1D target tensor expected, multi-target not supported

Can anyone please solve this error and help me where the code modification is required?

My dataset has two classes (N has 1340 images and P has 2680 images) with 128x128 X-ray images. Total images are 4020.

Thanks a million.

The loss calculation in nn.CrossEntropyLoss is failing as the shape is most likely wrong.
If you are using class indices as the target, nn.CrossEntropyLoss expects a model output in the shape of [batch_size, nb_classes, *] containing raw logits and a target in the shape [batch_size, *] (note the missing nb_classes dimension) containig class indices in the range [0, nb_classes-1].
In newer PyTorch releases it can also accept probabilities, but based on your error message I would guess you want to use class indices.

Thank you very much @ptrblck for your attention.
Yes. I am loading the dataset from folder directories considering the name of sub-directories as their labels and inputting to the model in the shape of torch tensor [batch size, channel, width, height]. My dataset has two classes. I am confused that where should I give this class indices information?