Apply cross entropy loss with custom weight map

I am solving a multi-class segmentation problem using the U-Net architecture. As specified in the U-Net paper, I am trying to implement custom weight maps to counter class imbalance.
Below is the code for the custom weight map:

```python
import numpy as np
from skimage.segmentation import find_boundaries

w0 = 10
sigma = 5

def make_weight_map(masks):
    """
    Generate the weight maps as specified in the UNet paper
    for a set of binary masks.

    Parameters
    ----------
    masks: array-like
        A 3D array of shape (n_masks, image_height, image_width),
        where each slice of the matrix along the 0th axis represents one binary mask.

    Returns
    -------
    array-like
        A 2D array of shape (image_height, image_width)
    """
    nrows, ncols = masks.shape[1:]
    masks = (masks > 0).astype(int)
    distMap = np.zeros((nrows * ncols, masks.shape[0]))
    X1, Y1 = np.meshgrid(np.arange(nrows), np.arange(ncols))
    X1, Y1 = np.c_[X1.ravel(), Y1.ravel()].T
    for i, mask in enumerate(masks):
        # find the boundary of each mask,
        # compute the distance of each pixel from this boundary
        bounds = find_boundaries(mask, mode='inner')
        X2, Y2 = np.nonzero(bounds)
        xSum = (X2.reshape(-1, 1) - X1.reshape(1, -1)) ** 2
        ySum = (Y2.reshape(-1, 1) - Y1.reshape(1, -1)) ** 2
        distMap[:, i] = np.sqrt(xSum + ySum).min(axis=0)
    ix = np.arange(distMap.shape[0])
    if distMap.shape[1] == 1:
        # single mask: only the distance to its own border (d1) is available
        d1 = distMap.ravel()
        border_loss_map = w0 * np.exp((-1 * (d1) ** 2) / (2 * (sigma ** 2)))
    else:
        # d1, d2: distances to the nearest and second-nearest mask borders
        if distMap.shape[1] == 2:
            d1_ix, d2_ix = np.argpartition(distMap, 1, axis=1)[:, :2].T
        else:
            d1_ix, d2_ix = np.argpartition(distMap, 2, axis=1)[:, :2].T
        d1 = distMap[ix, d1_ix]
        d2 = distMap[ix, d2_ix]
        border_loss_map = w0 * np.exp((-1 * (d1 + d2) ** 2) / (2 * (sigma ** 2)))
    xBLoss = np.zeros((nrows, ncols))
    xBLoss[X1, Y1] = border_loss_map
    # class weight map
    loss = np.zeros((nrows, ncols))
    w_1 = 1 - masks.sum() / loss.size
    w_0 = 1 - w_1
    loss[masks.sum(0) == 1] = w_1
    loss[masks.sum(0) == 0] = w_0
    ZZ = xBLoss + loss
    return ZZ
```
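For reference, the weight map from the U-Net paper that this function is modeled on can be written as below. Here w_c is the class-balancing map (the `loss` array in the code), d_1 and d_2 are the distances to the border of the nearest and second-nearest mask, and w_0 and σ correspond to `w0` and `sigma` (the paper uses w_0 = 10 and σ ≈ 5 pixels). When only one mask is present, the code uses d_1 alone in place of d_1 + d_2.

```latex
w(\mathbf{x}) = w_c(\mathbf{x}) + w_0 \cdot \exp\left( -\frac{\left( d_1(\mathbf{x}) + d_2(\mathbf{x}) \right)^2}{2\sigma^2} \right)
```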

I am not sure how to incorporate these custom weights into the cross-entropy loss. I came across this example for Keras but couldn't find anything similar for PyTorch.


Hello Rishav!

Does the weight argument described in the documentation for
CrossEntropyLoss do what you need?
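As a sketch of what the `weight` argument does: it is a 1D tensor with one entry per class, so every pixel whose target is class c is weighted by `weight[c]`. It cannot hold a per-pixel map. The shapes and weight values below are made-up assumptions for illustration.

```python
import torch
import torch.nn as nn

# Hypothetical setup: N=2 images, C=4 classes, 8x8 pixels.
num_classes = 4
logits = torch.randn(2, num_classes, 8, 8)         # network output, shape (N, C, H, W)
target = torch.randint(0, num_classes, (2, 8, 8))  # class indices, shape (N, H, W), dtype long

# One weight per class; these values are purely illustrative.
class_weights = torch.tensor([0.5, 2.0, 2.0, 1.0])
criterion = nn.CrossEntropyLoss(weight=class_weights)
loss = criterion(logits, target)  # scalar
```

With the default `reduction='mean'`, the weighted per-pixel losses are divided by the sum of the weights of the target classes, not by the number of pixels.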


K. Frank

Hi K. Frank,

The problem I am facing is how to pass the custom weights as an argument to CrossEntropyLoss. I would have to pass the masks as an argument, like this:

```python
criterion = torch.nn.CrossEntropyLoss(make_weight_map(mask))
```

which I guess is not the right way. I haven't approached this kind of problem before, so I'm pretty clueless right now.

Hi Rishav!

I’m not clear on what you are asking.

Do you want the contribution of a given prediction-target
pair to the loss to have a weight that depends (solely) on
the class of the target?

If so, I believe that the weight argument of CrossEntropyLoss
does what you want.

If this isn’t the case, could you write out a simple equation that
shows what you want the loss to be and how your “weights”
enter into it?

You make it sound like the weight argument doesn’t do what
you want, but you don’t say why. If you could explain how
weight doesn’t work, that might help make clear what it is
you want to achieve.


K. Frank

Hi K. Frank,
I want to apply the following operation:
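The operation in question is presumably the pixel-wise weighted cross-entropy from the U-Net paper. Written with the sign convention of a loss to be minimized, with Ω the set of pixels, ℓ(x) the true class of pixel x, and p_k(x) the softmax probability of class k at x, it reads as below. The weight w(x) varies per pixel, whereas the `weight` argument of CrossEntropyLoss assigns one weight per class.

```latex
L = -\sum_{\mathbf{x} \in \Omega} w(\mathbf{x}) \, \log p_{\ell(\mathbf{x})}(\mathbf{x})
```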

As I understand it, if I set the custom weights with criterion = torch.nn.CrossEntropyLoss(make_weight_map(mask)), I'll be using the masks outside of the data loaders, and I'm not sure whether that will work.
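One way around using the masks outside the data loader is to compute the weight map inside the dataset, so each batch arrives as (image, mask, weight map). A minimal sketch, assuming the images and label maps are in-memory NumPy arrays; the class name `SegmentationWithWeights` and the `weight_fn` helper are hypothetical.

```python
import numpy as np
import torch
from torch.utils.data import Dataset

class SegmentationWithWeights(Dataset):
    """Hypothetical dataset returning (image, class-index mask, weight map).

    weight_fn is an assumed helper: it takes an (H, W) integer label map and
    returns an (H, W) weight map, e.g. by splitting the label map into binary
    masks and calling make_weight_map on them.
    """

    def __init__(self, images, label_maps, weight_fn):
        self.images = images          # sequence of float arrays, shape (C_in, H, W)
        self.label_maps = label_maps  # sequence of int arrays, shape (H, W)
        self.weight_fn = weight_fn

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        label = np.asarray(self.label_maps[idx])
        # The weight map depends only on the ground truth, so it can be built
        # here, per sample, in NumPy on the CPU.
        weight_map = self.weight_fn(label)
        return (
            torch.as_tensor(self.images[idx], dtype=torch.float32),
            torch.as_tensor(label, dtype=torch.long),
            torch.as_tensor(weight_map, dtype=torch.float32),
        )
```

In the training loop each batch then unpacks as `images, targets, weights = batch`, and `weights` can be multiplied with the per-pixel loss from `CrossEntropyLoss(reduction='none')`. Because the map depends only on the ground truth, it could also be precomputed once per image and cached.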

You could do the following:

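```python
import torch
criterion = torch.nn.CrossEntropyLoss(reduction='none')
```

This makes the criterion return a loss value for each element (pixel) instead of a single scalar. You can then multiply each loss element by its weight:

```python
# Example tensors (hypothetical shapes): N=2 images, C=4 classes, 8x8 pixels
pd = torch.randn(2, 4, 8, 8)         # network output (logits), shape (N, C, H, W)
gt = torch.randint(0, 4, (2, 8, 8))  # ground truth, torch.long, shape (N, H, W)
W = torch.rand(2, 8, 8)              # per-element weighting based on the distance map from UNet
loss = criterion(pd, gt)             # per-pixel loss, shape (N, H, W)
loss = W * loss                      # ensure that the weights are scaled appropriately
loss = torch.sum(loss.flatten(start_dim=1), dim=1)  # sum the loss per image
loss = torch.mean(loss)              # average across the batch
```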

I've observed that weights of exactly 0 can cause complications, so it might be a good idea to add a small eps to W.


But I am converting the masks to class indices when feeding them into the data loader.

Do you mean there is a different mask for each class? If so, you'll need to write your own cross-entropy loss function.
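Writing the cross-entropy by hand is short with `log_softmax` and `gather`. A minimal sketch of a pixel-wise weighted version; the function name `weighted_cross_entropy` and the `eps=1e-6` default are assumptions, and the result matches `CrossEntropyLoss(reduction='none')` multiplied by `W + eps`, summed per image and averaged over the batch.

```python
import torch
import torch.nn.functional as F

def weighted_cross_entropy(logits, target, weight_map, eps=1e-6):
    """Pixel-wise weighted cross-entropy, written out by hand (hypothetical helper).

    logits:     (N, C, H, W) raw network output
    target:     (N, H, W) class indices, dtype torch.long
    weight_map: (N, H, W) per-pixel weights, e.g. from a UNet-style weight map
    eps:        small constant added to the weights to avoid exact zeros
    """
    log_probs = F.log_softmax(logits, dim=1)  # (N, C, H, W)
    # Pick the log-probability of the true class at every pixel.
    true_log_probs = log_probs.gather(1, target.unsqueeze(1)).squeeze(1)  # (N, H, W)
    per_pixel = -(weight_map + eps) * true_log_probs
    per_image = per_pixel.flatten(start_dim=1).sum(dim=1)  # sum over pixels
    return per_image.mean()  # mean over the batch
```

If each class had its own weight map of shape (N, C, H, W), the same `gather` call along dimension 1 could select the weight belonging to the true class at each pixel.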

No, each mask has at most 4 classes (red, blue, green and black). I did what you suggested apart from the last statement, loss = torch.mean(loss), because I am calculating the average across batches using train_loss = train_loss + ((1 / (batch_idx + 1)) * ( - train_loss)). I was getting an error because my make_weight_map(masks) expected a NumPy array instead of a tensor, and I had also forgotten to move the tensor to the CPU; I've fixed both now. But now my system seems to have stalled.
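The running-average update in this post follows the incremental-mean formula avg_k = avg_(k-1) + (x_k - avg_(k-1)) / k, where x_k is the current batch's loss. A minimal sketch over a hypothetical list of per-batch loss values:

```python
def running_mean(values):
    """Incremental mean: avg_k = avg_{k-1} + (x_k - avg_{k-1}) / k."""
    avg = 0.0
    for batch_idx, value in enumerate(values):
        # batch_idx starts at 0, so batch_idx + 1 values have been seen so far
        avg = avg + (1 / (batch_idx + 1)) * (value - avg)
    return avg
```

In a training loop, `value` would be the scalar loss of the current batch (for example `loss.item()`), and the result equals the plain mean of all batch losses without storing them.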
Edit: I now get a memory error, even though I have 32 GB of RAM.

Okay, I increased my resources to 61 GB of RAM and a 16 GB Tesla V100 GPU, and I still got the memory error. I think the approach is wrong.
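The memory error is consistent with `make_weight_map` building, for each mask, dense matrices of size (number of boundary pixels) × (number of pixels): a 512 × 512 image with 3,000 boundary pixels gives about 786 million 8-byte entries, roughly 6 GB per intermediate array. A sketch of an alternative that swaps in `scipy.ndimage.distance_transform_edt` (not part of the original code) to get the same Euclidean distance to the nearest boundary pixel with memory proportional to the image size; the name `make_weight_map_edt` is hypothetical.

```python
import numpy as np
from scipy.ndimage import distance_transform_edt
from skimage.segmentation import find_boundaries

def make_weight_map_edt(masks, w0=10, sigma=5):
    """Same weight map as make_weight_map, computed with a distance transform.

    masks: array of shape (n_masks, H, W); each slice is one binary mask.
    Returns an (H, W) float array.
    """
    masks = (np.asarray(masks) > 0).astype(int)
    n_masks, nrows, ncols = masks.shape
    dist = np.empty((n_masks, nrows, ncols))
    for i, mask in enumerate(masks):
        bounds = find_boundaries(mask, mode='inner')
        # distance_transform_edt gives each nonzero pixel's distance to the
        # nearest zero pixel, so invert: boundary pixels become the zeros.
        dist[i] = distance_transform_edt(~bounds)
    if n_masks == 1:
        d_sum = dist[0]
    else:
        # Two smallest distances per pixel: nearest (d1) and second-nearest (d2) border.
        two_smallest = np.partition(dist, 1, axis=0)[:2]
        d_sum = two_smallest[0] + two_smallest[1]
    border = w0 * np.exp(-(d_sum ** 2) / (2 * sigma ** 2))
    # Class-balancing term, as in the original function.
    class_w = np.zeros((nrows, ncols))
    w_1 = 1 - masks.sum() / class_w.size
    w_0 = 1 - w_1
    class_w[masks.sum(0) == 1] = w_1
    class_w[masks.sum(0) == 0] = w_0
    return border + class_w
```

Since the weight map depends only on the ground-truth mask, it can also be precomputed once per training image instead of being recomputed in every training step.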

Hi Rishav, it's been a while, so I guess you got this working. Let me know if you need help fixing this issue.