# Classes, Accuracy, and Loss in Pixel-wise Classification Using Masks

I'm trying to train the model using a UNet framework to detect cracks in roads. My images are grayscale, with values between 0 and 1.0, and have shape (batchsize, #classes, image_height, image_width). My masks are binary, with 0 for background (which I don't care about) and 1 for the crack sections. My first question: can I get away with using only 1 class, or do I need to use 2? I'm also not sure exactly how to calculate accuracy, and I'm a little stuck on getting useful information from the loss and accuracy.

loss_function = nn.NLLLoss(reduction='none')     # reduction of none returns the same size image
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
initial = torch.rand([b_size, image_width, image_height]).to(device)  # b_size might be wrong


def fwd_pass(X, y):
    if train:
        model.zero_grad()       # Clears the gradients
    output = model(X)
    logp = F.log_softmax(output, dim=1)
    loss = loss_function(logp, y)
    # acc = ???
    if train:
        loss.backward(gradient=initial) 
        optimizer.step()
    return loss, # acc

Hello Ryan!

You’re basically on the right track. Let me make some suggestions
(but, note, you can usually do more-or-less the same thing in a
couple of different ways).

You are working on a binary segmentation problem. (This is a
binary classifier, where you classify individual pixels.) You could
treat this as the two-class case of a general multi-class problem,
but you’re marginally better off treating this as a binary problem.
So don’t explicitly use two classes.

The images you input to your model should have shape
[batchsize, height, width] – no class dimension,
because your input images don’t know anything about class.
The output of your model – which will be one of the inputs to
your loss function – should have the same shape – also no
class dimension (because we're implementing this as a binary
problem). Your masks (sometimes called labels or targets) – the
other input to your loss function – should also have the same
shape, again no class dimension (the class information is contained
in the binary values of the mask pixels). And even though your
masks are (in your case) binary, they should be FloatTensors.

You should use BCEWithLogitsLoss. (NLLLoss is reasonable,
but is for the multi-class problem, and you’re better off going the
pure binary route.) The output of your model should be raw-score
logits that run from -inf to inf, so you shouldn’t pass them
through a non-linearity such as sigmoid().
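
As a minimal sketch (dummy tensors, just to show the call pattern;
the shapes are placeholders):

import torch
import torch.nn as nn

logits = torch.randn(4, 256, 256)                    # raw model output -- no sigmoid() applied
masks = torch.randint(0, 2, (4, 256, 256)).float()   # binary targets, stored as floats, same shape

loss_function = nn.BCEWithLogitsLoss()               # applies the sigmoid internally
loss = loss_function(logits, masks)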

Conceptually, you convert a prediction logit to a probability – to be
understood as the (predicted) probability of a pixel being in the
“1” (crack) class – by passing it through a sigmoid(). You would
typically convert a prediction probability to a yes-no prediction by
thresholding the probability – p > 0.5 means class-“1”.
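
Continuing the sketch above:

probs = torch.sigmoid(logits)     # per-pixel predicted probability of the "1" (crack) class
preds = (probs > 0.5).float()     # hard yes-no prediction per pixel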

Because sigmoid() maps a logit of 0.0 to a probability of 0.5,
you can simply compare your logits (the output of your model)
with a threshold of 0.0. These are your yes-no predictions for
your pixels, and your accuracy is typically taken to be the fraction
of pixels for which these predictions are correct, i.e., are equal to
your mask values.
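
With the dummy tensors from the sketch above, a pixel-wise accuracy
could therefore look like this:

preds = (logits > 0.0).float()                # same as thresholding sigmoid(logits) at 0.5
acc = (preds == masks).float().mean().item()  # fraction of pixels classified correctly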

One last comment: It may be the case that you have many more
“background” pixels than “crack” pixels. If so, you might want to
experiment with using a class weight in your loss function. (Look at
the pos_weight argument to BCEWithLogitsLoss constructor.)
But you should try things first with an unweighted loss function just
to see how well the plain-vanilla approach works.
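
If you do end up needing it, the weighted constructor call might look
like this (the value 10.0 is just an illustrative placeholder, not a
recommendation):

# penalize mistakes on the positive ("crack") pixels 10x as heavily -- placeholder value
loss_function = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([10.0]))

A common starting point is to set pos_weight roughly to the ratio of
background pixels to crack pixels in your training set.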

Best.

K. Frank


I’m having trouble using the input image with shape (batchsize, image_height, image_width). I’m getting “RuntimeError: Expected 4-dimensional input for 4-dimensional weight 32 1 3 3, but got 3-dimensional input of size [1, 256, 256] instead”. Probably something simple, but I can’t see it.

I deleted these two lines

image = np.reshape(image, (1, image_height, image_width))
out = output.squeeze(1)

And changed this line

loss = loss_function(out, y)

To

loss = loss_function(output, y)

This is most of my code.


k_size = 3
pad = 1
class UNetMini(Module):

    def __init__(self):
        super(UNetMini, self).__init__()

        self.block1 = Sequential(
            Conv2d(1, 32, kernel_size=k_size, padding=pad),
            ReLU(),
            Dropout2d(0.2),
            Conv2d(32, 32, kernel_size=k_size, padding=pad),
            ReLU(),
        )
        self.pool1 = MaxPool2d((2, 2))

        self.block2 = Sequential(
            Conv2d(32, 64, kernel_size=k_size, padding=pad),
            ReLU(),
            Dropout2d(0.2),
            Conv2d(64, 64, kernel_size=k_size, padding=pad),
            ReLU(),
        )
        self.pool2 = MaxPool2d((2, 2))

        self.block3 = Sequential(
            Conv2d(64, 128, kernel_size=k_size, padding=pad),
            ReLU(),
            Dropout2d(0.2),
            Conv2d(128, 128, kernel_size=k_size, padding=pad),
            ReLU()
        )

        self.up1 = UpsamplingNearest2d(scale_factor=2)
        self.block4 = Sequential(
            Conv2d(192, 64, kernel_size=k_size, padding=pad),
            ReLU(),
            Dropout2d(0.2),
            Conv2d(64, 64, kernel_size=k_size, padding=pad),
            ReLU()
        )

        self.up2 = UpsamplingNearest2d(scale_factor=2)
        self.block5 = Sequential(
            Conv2d(96, 32, kernel_size=k_size, padding=pad),
            ReLU(),
            Dropout2d(0.2),
            Conv2d(32, 32, kernel_size=k_size, padding=pad),
            ReLU()
        )

        self.conv2d = Conv2d(32, 1, kernel_size=1)   # 1x1 conv down to a single logit channel

    def forward(self, x):
        out1 = self.block1(x)
        out_pool1 = self.pool1(out1)

        out2 = self.block2(out_pool1)
        out_pool2 = self.pool2(out2)

        out3 = self.block3(out_pool2)

        out_up1 = self.up1(out3)
        # return out_up1
        out4 = torch.cat((out_up1, out2), dim=1)
        out4 = self.block4(out4)

        out_up2 = self.up2(out4)
        out5 = torch.cat((out_up2, out1), dim=1)
        out5 = self.block5(out5)

        out = self.conv2d(out5)

        return out


if rebuild_data:
    data = BuildData()
    data.make_training_data()

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

processed_images = np.load("/home/ryan/CrackProject/DataSets/Data/images.npy", allow_pickle=True)
processed_masks = np.load("/home/ryan/CrackProject/DataSets/Data/masks.npy", allow_pickle=True)


model = UNetMini().to(device)


# This builds the data set
class FormsDataset(Dataset):
    def __init__(self, images, masks, transforms):
        self.images = images
        self.masks = masks
        self.transforms = transforms

    def __getitem__(self, idx):
        image = self.images[idx]
        image = image / 255
        image = np.reshape(image, (1, image_height, image_width)) 
        if self.transforms:
            image = self.transforms(image)

        mask = self.masks[idx]
        mask[mask > .7] = 1
        mask[mask <= .7] = 0
        if self.transforms:
            mask = self.transforms(mask)
        return image, mask

    def __len__(self):
        return len(self.images)


train_dataset = FormsDataset(processed_images, processed_masks, trans)
train_data_loader = DataLoader(train_dataset, batch_size=b_size, shuffle=False)    

# Training loop
total_steps = len(train_data_loader)
print(f'Train dataset has {len(train_data_loader)} batches of size {b_size}')

# Loss and Optimization
loss_function = nn.BCEWithLogitsLoss(reduction='none')
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
initial = torch.rand([b_size, image_height, image_width]).to(device)


def fwd_pass(X, y):
    if train:
        model.zero_grad()       # Clears the gradients from the last step, otherwise they would accumulate
    output = model(X)
    out = output.squeeze(1)
    loss = loss_function(out, y)
    acc = []
    if train:
        loss.backward(gradient=initial)  # compute the derivative of the loss w.r.t. the parameters (backprop)
        optimizer.step()
    return acc, loss, output


def train():
    total_steps = len(train_data_loader)
    print(f"{epochs} epochs, {total_steps} total_steps per epoch")

    for epoch in tqdm(range(epochs), desc="Epochs"):
        for i, (images, masks) in enumerate(train_data_loader):
            images = images.type(torch.FloatTensor)
            images = images.to(device)

            masks = masks.type(torch.FloatTensor)
            masks = masks.to(device)

            acc, loss, output = fwd_pass(images, masks)

train()

Hello Ryan!

Since you’re using a version of “UNet” and the first layer of your
model is:

Conv2d(1, 32, kernel_size=k_size, padding=pad)

it looks like your model is expecting inputs of shape
[nBatch, nChannel, height, width], where nChannel = 1.
This makes sense for a grayscale image. (If your model were set up
for color images, you would probably have nChannel = 3 for the
three rgb channels.)
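
So if, for example, a batch comes out of your DataLoader with shape
[nBatch, height, width], one way to add the missing channel dimension
is (a sketch):

import torch

images = torch.rand(8, 256, 256)   # [nBatch, height, width] -- no channel dimension yet
images = images.unsqueeze(1)       # [nBatch, 1, height, width], which Conv2d(1, 32, ...) expects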

What was the shape of a batch of images to input to your model
before you did any reshaping, etc.? What was the shape of the
output of your model before reshaping? What is the shape of a
batch of masks that you pass to your loss function?

The point is that “UNet” typically carries along a “channel”
dimension, which, in your case, appears to be of size 1 for both
the input and output of your model.

BCEWithLogitsLoss requires that the shape of its input (the
output of your model) and its target (your mask) be the same.
They can both have this “singleton” nChannel = 1 dimension,
but, if so, they both have to have it.
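
Concretely (a sketch with dummy shapes), you can either squeeze the
singleton channel dimension out of the model's output or unsqueeze one
into the mask, as long as the two shapes end up matching:

import torch
import torch.nn as nn

output = torch.randn(8, 1, 256, 256)                  # model output with the singleton channel dim
masks = torch.randint(0, 2, (8, 256, 256)).float()    # masks without it

# option 1: drop the channel dimension from the output
loss = nn.BCEWithLogitsLoss()(output.squeeze(1), masks)

# option 2: add a channel dimension to the masks
loss = nn.BCEWithLogitsLoss()(output, masks.unsqueeze(1))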

This line, though, looks wrong:

loss_function = nn.BCEWithLogitsLoss(reduction='none')

reduction='none' means don’t sum (or average) the loss over the
elements of output (and target). But, for backpropagation, you want
a single scalar loss, so you should use the default
reduction = 'mean'.
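
In other words, something along these lines (a sketch reusing the
names from your training code above, not a drop-in replacement):

loss_function = nn.BCEWithLogitsLoss()      # default reduction='mean' -> a single scalar loss

output = model(X)                           # X: [nBatch, 1, height, width]
loss = loss_function(output.squeeze(1), y)  # y: [nBatch, height, width] float masks
loss.backward()                             # scalar loss, so no gradient= argument is needed
optimizer.step()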

Best.

K. Frank

Thanks, Frank. I finally got it working. My data set is still very small, around 100, so I’m guessing it is overfitting the data. Does anything stand out to you in the graphs?

[Figure_1: training graphs]