image segmentation with cross-entropy loss

I am a new user of PyTorch.
I’d like to use the cross-entropy loss function for this setup:

number of classes=2
output.shape=[4,2,224,224]
output_min=tensor(-1.9295)
output_max=tensor(2.6400)

number of channels=3
target.shape=[4,3,224,224]
targets_max=tensor(-2.1008)
targets_min=tensor(-2.1179)

How do I evaluate loss = criterion(output, target) with these tensors?
Thanks.

Hello Neo!

As an aside, for a two-class classification problem, you will be
better off treating this explicitly as a binary problem, rather than
as a two-class instance of the more general multi-class problem.
To do so you would use BCEWithLogitsLoss (“Binary Cross
Entropy”), rather than the multi-class CrossEntropyLoss.
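A minimal sketch of the binary route, assuming the model is changed to output a single logit channel per pixel (shape [nBatch, 1, height, width]) and the mask holds 0 for background and 1 for foreground. Random tensors stand in for real model output and masks:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

n_batch, height, width = 4, 224, 224

# Assumed model output: one raw logit per pixel (no sigmoid applied).
output = torch.randn(n_batch, 1, height, width)

# Assumed binary mask: 0 = background, 1 = foreground.
# BCEWithLogitsLoss wants a float target with the same shape as the output.
mask = torch.randint(0, 2, (n_batch, height, width))
target = mask.unsqueeze(1).float()  # [4, 224, 224] -> [4, 1, 224, 224]

criterion = nn.BCEWithLogitsLoss()
loss = criterion(output, target)
print(loss.item())

# At inference time, a logit above 0 means a probability above 0.5.
pred = (output > 0).long().squeeze(1)  # [4, 224, 224] of 0/1 labels
```

Unlike CrossEntropyLoss, BCEWithLogitsLoss takes a float target shaped exactly like the output.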

But you can certainly treat this as a general multi-class problem,
and I will answer your question in this context.

So your outputs (from -1.9295 to 2.6400) are raw-score logits, rather
than probabilities that lie between 0.0 and 1.0. Good.
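Raw logits are what CrossEntropyLoss wants because it applies log_softmax over the class dimension internally. A small check with toy random tensors (not the poster's data):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Raw scores (logits) for 2 classes over a small 3x3 "image", batch of 1.
logits = torch.randn(1, 2, 3, 3)
labels = torch.randint(0, 2, (1, 3, 3))

# cross_entropy applies log_softmax over the class dimension (dim=1) itself...
ce = F.cross_entropy(logits, labels)

# ...so it matches log_softmax followed by the negative log-likelihood loss.
manual = F.nll_loss(F.log_softmax(logits, dim=1), labels)

print(ce.item(), manual.item())
```

Feeding softmax probabilities instead of logits would apply the softmax twice.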

I assume that your first dimension, 4, is your batch size. This is
fine.

It probably does not make sense to have a channels dimension in
your target. (If you think it does, you should further explain your
use case.) In any event, as it stands, this target shape won’t match
your output shape.

If your output shape is [nBatch, nClass, height, width],
then (for CrossEntropyLoss) your target shape must be
[nBatch, height, width], with no nClass dimension.
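A sketch of shapes that CrossEntropyLoss accepts, using the sizes from the question (batch 4, 2 classes, 224 x 224) and random tensors in place of the real model output and mask:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

n_batch, n_class, height, width = 4, 2, 224, 224

# Model output: [nBatch, nClass, height, width] raw logits.
output = torch.randn(n_batch, n_class, height, width)

# Target: [nBatch, height, width], one class index per pixel (no class dim).
target = torch.randint(0, n_class, (n_batch, height, width))

criterion = nn.CrossEntropyLoss()
loss = criterion(output, target)  # scalar, averaged over all pixels
print(output.shape, target.shape, loss.item())
```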

Your target values (targets_min = -2.1179, targets_max = -2.1008) are
wrong (for CrossEntropyLoss). The target values must be integer (long)
class labels that run from 0 to nClass - 1, so in your two-class case
they take on the values 0 and 1.
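A sketch of a sanity check one could run on a target before the loss call. `check_ce_target` is a hypothetical helper, not a PyTorch function; it casts whole-number masks to long and rejects values that are not valid class labels:

```python
import torch


def check_ce_target(target: torch.Tensor, n_class: int) -> torch.Tensor:
    """Hypothetical helper: validate and return a CrossEntropyLoss class-index target."""
    # A float mask is only usable if every value is a whole number.
    if target.is_floating_point() and not torch.equal(target, target.round()):
        raise ValueError("target holds non-integer values; was it normalized like an image?")
    target = target.long()
    lo, hi = int(target.min()), int(target.max())
    if lo < 0 or hi > n_class - 1:
        raise ValueError(f"labels must lie in [0, {n_class - 1}], got [{lo}, {hi}]")
    return target


# A float mask of whole numbers is accepted and cast to long.
good = check_ce_target(torch.tensor([[0.0, 1.0], [1.0, 0.0]]), n_class=2)
print(good.dtype)

# Values like the ones printed in the question are rejected.
rejected = None
try:
    check_ce_target(torch.tensor([[-2.1179, -2.1008]]), n_class=2)
except ValueError as err:
    rejected = str(err)
print(rejected)
```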

If you could explain a little more where target comes from, and
what the numbers actually mean, we can help sort this out.

Best.

K. Frank

Sorry for the delay, I had some problems.
And sorry for my English.

My model:

number of classes=2 or 3 or 10

The output dimension of the model is [No x Co x Ho x Wo]
where,

No -> the batch size (same as Ni)
Co -> the number of classes in the dataset
Ho -> the height of the image (which is the same as Hi in almost all cases)
Wo -> the width of the image (which is the same as Wi in almost all cases)

number of channels=1 or 3

The target dimension is [Ni x Ci x Hi x Wi]
where,

Ni -> the batch size
Ci -> the number of channels (which is 3 or 1)
Hi -> the height of the image
Wi -> the width of the image 

And thanks for your reply.

I apply the following training transformations to both the image and the mask:

```python
# `et` (transforms applied jointly to image and mask) and `opts` come from
# my training script and are not shown here.
train_transform = et.ExtCompose([
    # et.ExtResize(size=opts.crop_size),
    et.ExtRandomScale((0.5, 2.0)),
    et.ExtRandomCrop(size=(opts.crop_size, opts.crop_size), pad_if_needed=True),
    et.ExtRandomHorizontalFlip(),
    et.ExtToTensor(),
    et.ExtNormalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]),
])
```

The mask is RGB or grayscale, with values 0 for background, 1 for class 1, and 2 for class 2.
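Both target values printed in the first post match what mask labels 0 and 1 become when scaled by 1/255 and then normalized with mean 0.485 and std 0.229 (the first-channel statistics in the ExtNormalize call above). This is an inference from the numbers only; the thread doesn't confirm how the mask was processed. A quick arithmetic check:

```python
# Inference, not confirmed in the thread: a raw mask value v that is scaled to
# v / 255 and then normalized with mean 0.485, std 0.229 becomes
# (v / 255 - 0.485) / 0.229.
MEAN_R, STD_R = 0.485, 0.229


def normalized(v: int) -> float:
    return (v / 255 - MEAN_R) / STD_R


def recover_label(x: float) -> int:
    # Undo the normalization and the 1/255 scaling, then round to a label.
    return round((x * STD_R + MEAN_R) * 255)


print(round(normalized(0), 4))  # compare with targets_min in the question
print(round(normalized(1), 4))  # compare with targets_max in the question
print(recover_label(-2.1179), recover_label(-2.1008))
```

If that is what happened, the mean/std normalization belongs on the image only, and the mask should reach the loss as unscaled integer labels.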

Hello Neo!

Ho and Wo must be the same as Hi and Wi in all cases, not
just in “almost all” cases.

A target of shape [Ni x Ci x Hi x Wi] might be what you have, but it
simply won’t work.
CrossEntropyLoss requires that, for a model output of shape
[No, Co, Ho, Wo], the target have shape [No, Ho, Wo]
(and that the values of the target are integer class labels that
run from 0 to Co - 1).

If your target has this extra “channel” dimension (Ci), it won’t
work (and Hi and Wi must match Ho and Wo, as well). (Just
to be clear, you can’t have the Ci dimension at all, even if
Ci = 1.)
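A sketch of removing the Ci dimension, with random labels standing in for real masks. The RGB case assumes every channel carries the same label value (as when a grayscale mask is saved as RGB); if classes are instead encoded as distinct colors, a color-to-index mapping would be needed:

```python
import torch

torch.manual_seed(0)
n_batch, height, width = 4, 224, 224

# Case 1: grayscale mask with a singleton channel, shape [N, 1, H, W].
gray_mask = torch.randint(0, 3, (n_batch, 1, height, width))
target_from_gray = gray_mask.squeeze(1).long()  # -> [N, H, W]

# Case 2 (assumption): an "RGB" mask whose three channels all carry the same
# label value, shape [N, 3, H, W]. Any single channel then holds the labels.
rgb_mask = gray_mask.expand(-1, 3, -1, -1)
assert torch.equal(rgb_mask[:, 0], rgb_mask[:, 1])  # channels are identical
target_from_rgb = rgb_mask[:, 0].long()  # -> [N, H, W]

print(target_from_gray.shape, target_from_rgb.shape)
```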

Good luck.

K. Frank

Thanks, Frank, for your reply.
I’m going to look for another way.

Hey, sorry to disturb you, I just wanted to confirm:

  1. The raw logits are supposed to be one-hot encoded, say with a sample shape of (1, 6, 256, 256) for multi-class classification with 6 labels.
  2. The target passed to the loss function then has to be non-one-hot-encoded: the true integer labels in their original form.

I am confused about why PyTorch doesn’t do this implicitly, though. Was my understanding correct?
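A sketch of the two target forms, with random tensors in place of real data. The logits themselves are not one-hot: they hold one raw score per class per pixel. The probability-target variant assumes PyTorch 1.10 or later, where CrossEntropyLoss also accepts class probabilities shaped like the input:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
n_class = 6

# Raw logits: one channel per class, shape (1, 6, 256, 256).
logits = torch.randn(1, n_class, 256, 256)

# Integer class-index target, shape (1, 256, 256).
labels = torch.randint(0, n_class, (1, 256, 256))

criterion = nn.CrossEntropyLoss()
loss_index = criterion(logits, labels)

# Probability target: one_hot puts the class dim last, so move it to dim 1.
one_hot = F.one_hot(labels, n_class).permute(0, 3, 1, 2).float()
loss_prob = criterion(logits, one_hot)

# The two losses agree up to floating-point rounding.
print(loss_index.item(), loss_prob.item())
```

With no class weights or label smoothing, the class-index target is simply the compact equivalent of the one-hot form.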

So how do I get rid of the “channel” dimension in my case?

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# nc (image channels), ndf, ngf and prior() are defined elsewhere in my
# notebook (not shown here).


class Dir_VAE(nn.Module):
    def __init__(self):
        super(Dir_VAE, self).__init__()
        self.encoder = nn.Sequential(
            # input is (nc) x 256 x 256
            nn.Conv2d(nc, ndf, 4, 4, 0, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 64 x 64
            nn.Conv2d(ndf, ndf * 2, 4, 4, 0, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 4, 0, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 4 x 4
            nn.Conv2d(ndf * 4, 512, 4, 2, 0, bias=False),
            # nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. 512 x 1 x 1
            # nn.Sigmoid()
        )

        self.decoder = nn.Sequential(
            # input is Z (512 x 1 x 1), going into a convolution
            nn.ConvTranspose2d(512, ngf * 4, 4, 2, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 4 x 4
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 4, 0, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf * 2, 4, 4, 0, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 64 x 64
            nn.ConvTranspose2d(ngf * 2, nc, 4, 4, 0, bias=False),
            # nn.BatchNorm2d(ngf),
            # nn.ReLU(True),
            # nn.ConvTranspose2d(    ngf,      nc, 4, 2, 1, bias=False),
            # nn.Tanh()
            nn.Sigmoid()
            # state size. (nc) x 256 x 256, values in (0, 1)
        )
        self.fc1 = nn.Linear(512, 256)
        self.fc21 = nn.Linear(256, 10)
        self.fc22 = nn.Linear(256, 10)

        self.fc3 = nn.Linear(10, 256)
        self.fc4 = nn.Linear(256, 512)

        self.lrelu = nn.LeakyReLU()
        self.relu = nn.ReLU()

        # Dirichlet prior (prior() is defined elsewhere, not shown)
        self.prior_mean, self.prior_var = map(nn.Parameter, prior(10, 0.3))  # 0.3 is a hyperparameter of the Dirichlet distribution
        self.prior_logvar = nn.Parameter(self.prior_var.log())
        self.prior_mean.requires_grad = False
        self.prior_var.requires_grad = False
        self.prior_logvar.requires_grad = False

    def encode(self, x):
        conv = self.encoder(x)
        print('Size', conv.shape)
        h1 = self.fc1(conv.view(-1, 512))
        return self.fc21(h1), self.fc22(h1)

    def decode(self, gauss_z):
        dir_z = F.softmax(gauss_z, dim=1)
        # This variable (dir_z) can be treated as following a Dirichlet distribution
        # (it can be interpreted as probabilities that sum to 1).
        # Use the softmax function to satisfy the simplex constraint.
        h3 = self.relu(self.fc3(dir_z))
        deconv_input = self.fc4(h3)
        print('Deconv ', deconv_input.shape)
        deconv_input = deconv_input.view(-1, 512, 1, 1)
        return self.decoder(deconv_input)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu, logvar = self.encode(x)
        gauss_z = self.reparameterize(mu, logvar)
        # gauss_z is a variable that follows a multivariate normal distribution.
        # Passing gauss_z through softmax yields a random variable that follows
        # a Dirichlet distribution (softmax is also applied inside decode).
        dir_z = F.softmax(gauss_z, dim=1)  # This variable follows a Dirichlet distribution
        return self.decode(gauss_z), mu, logvar, gauss_z, dir_z

    # Reconstruction loss (summed over all elements and the batch) + KL divergence (per sample)
    def loss_function(self, recon_x, x, mu, logvar, K):
        print('Recon ', recon_x.shape)
        print('Data ', x.shape)
        # 65536 = 256 * 256 pixels per single-channel image
        BCE = F.binary_cross_entropy(recon_x.view(-1, 65536), x.view(-1, 65536), reduction='sum')
        # Calculating KL with Dirichlet prior and variational posterior distributions
        # Original paper: "Autoencoding Variational Inference for Topic Models" - https://arxiv.org/pdf/1703.01488
        ''' prior_mean = self.prior_mean.expand_as(mu)
        prior_var = self.prior_var.expand_as(logvar)
        prior_logvar = self.prior_logvar.expand_as(logvar)
        var_division = logvar.exp() / prior_var # Σ_0 / Σ_1
        diff = mu - prior_mean # μ_1 - μ_0
        diff_term = diff *diff / prior_var # (μ_1 - μ_0)(μ_1 - μ_0)/Σ_1
        logvar_division = prior_logvar - logvar # log|Σ_1| - log|Σ_0| = log(|Σ_1|/|Σ_0|)
        # KL
        KLD = 0.5 * ((var_division + diff_term + logvar_division).sum(1) - K) '''
        # Active KL term: against a standard normal prior (the Dirichlet version above is commented out)
        KLD = -0.5 * torch.sum(1 + logvar - mu**2 - torch.exp(logvar), dim=1)
        return BCE + KLD
```

```
Deconv  torch.Size([1, 512])
Recon  torch.Size([1, 1, 256, 256])
Data  torch.Size([1, 1, 256, 256])
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-33-5c58cc28c4e5> in <module>
    266     # Training (Train)
    267     for epoch in range(1, 10):
--> 268         train(epoch)
    269         test(epoch)
    270         with torch.no_grad():

<ipython-input-33-5c58cc28c4e5> in train(epoch)
    226         optimizer.zero_grad()
    227         recon_batch, mu, logvar, gauss_z, dir_z = model(data)
--> 228         loss = model.loss_function(recon_batch, data, mu, logvar, 10)
    229         loss = loss.mean()
    230         loss.backward()

<ipython-input-33-5c58cc28c4e5> in loss_function(self, recon_x, x, mu, logvar, K)
    198         print('Recon ',recon_x.shape)
    199         print('Data ' ,x.shape)
--> 200         BCE = F.binary_cross_entropy(recon_x.view(-1, 65536), x.view(-1, 65536), reduction='sum')
    201         # Calculate the KL between the Dirichlet prior and the variational posterior
    202         # Calculating KL with Dirichlet prior and variational posterior distributions

C:\Conda5\lib\site-packages\torch\nn\functional.py in binary_cross_entropy(input, target, weight, size_average, reduce, reduction)
   2760         weight = weight.expand(new_size)
   2761 
-> 2762     return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum)
   2763 
   2764 

RuntimeError: all elements of input should be between 0 and 1
```
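`F.binary_cross_entropy` expects its input to be probabilities in [0, 1], and the BCELoss documentation asks for targets in [0, 1] as well; a NaN in the input would also fail the range check, since NaN compares false against both bounds. A sketch of a quick inspection before the call, where `describe_bce_inputs` is a hypothetical helper and the example tensors are made up:

```python
import torch


def describe_bce_inputs(recon_x: torch.Tensor, x: torch.Tensor) -> dict:
    """Hypothetical debugging helper: summarize what F.binary_cross_entropy will see."""
    return {
        "input_has_nan": bool(torch.isnan(recon_x).any()),
        "input_min": float(recon_x.min()),
        "input_max": float(recon_x.max()),
        "target_min": float(x.min()),
        "target_max": float(x.max()),
    }


# Made-up example: a valid sigmoid output, and a target that went through a
# mean/std normalization and therefore falls outside [0, 1].
recon = torch.sigmoid(torch.linspace(-3, 3, 16).view(1, 1, 4, 4))
data = (torch.linspace(0, 1, 16).view(1, 1, 4, 4) - 0.485) / 0.229
info = describe_bce_inputs(recon, data)
print(info)
```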