# Apply cross entropy loss with custom weight map

I am solving a multi-class segmentation problem using the U-Net architecture. As described in the U-Net paper, I am trying to implement custom weight maps to counter class imbalance.
Below is the code for the custom weight map:

```python
import numpy as np
from skimage.segmentation import find_boundaries

w0 = 10
sigma = 5


def make_weight_map(masks):
    """
    Generate the weight maps as specified in the UNet paper
    for a set of binary masks.

    Parameters
    ----------
    masks : numpy.ndarray
        A 3D array of shape (n_masks, image_height, image_width),
        where each slice of the matrix along the 0th axis represents one binary mask.

    Returns
    -------
    array-like
        A 2D array of shape (image_height, image_width)

    """
    nrows, ncols = masks.shape[1:]
    masks = (masks > 0).astype(int)
    distMap = np.zeros((nrows * ncols, masks.shape[0]))
    X1, Y1 = np.meshgrid(np.arange(nrows), np.arange(ncols))
    X1, Y1 = np.c_[X1.ravel(), Y1.ravel()].T
    for i, mask in enumerate(masks):
        # find the boundary of each mask,
        # compute the distance of each pixel from this boundary
        bounds = find_boundaries(mask, mode='inner')
        X2, Y2 = np.nonzero(bounds)
        # xSum / ySum have shape (n_boundary_pixels, nrows * ncols)
        xSum = (X2.reshape(-1, 1) - X1.reshape(1, -1)) ** 2
        ySum = (Y2.reshape(-1, 1) - Y1.reshape(1, -1)) ** 2
        distMap[:, i] = np.sqrt(xSum + ySum).min(axis=0)
    ix = np.arange(distMap.shape[0])
    if distMap.shape[1] == 1:
        d1 = distMap.ravel()
        border_loss_map = w0 * np.exp((-1 * (d1) ** 2) / (2 * (sigma ** 2)))
    else:
        # indices of the nearest and second-nearest mask for every pixel
        if distMap.shape[1] == 2:
            d1_ix, d2_ix = np.argpartition(distMap, 1, axis=1)[:, :2].T
        else:
            d1_ix, d2_ix = np.argpartition(distMap, 2, axis=1)[:, :2].T
        d1 = distMap[ix, d1_ix]
        d2 = distMap[ix, d2_ix]
        border_loss_map = w0 * np.exp((-1 * (d1 + d2) ** 2) / (2 * (sigma ** 2)))
    xBLoss = np.zeros((nrows, ncols))
    xBLoss[X1, Y1] = border_loss_map
    # class weight map: foreground pixels get w_1, background pixels get w_0
    loss = np.zeros((nrows, ncols))
    w_1 = 1 - masks.sum() / loss.size
    w_0 = 1 - w_1
    loss[masks.sum(0) == 1] = w_1
    loss[masks.sum(0) == 0] = w_0
    ZZ = xBLoss + loss
    return ZZ
```

I am really not sure how to incorporate custom weights into the cross-entropy loss. I came across this example for Keras but couldn't find any source for PyTorch.


Hello Rishav!

Does the `weight` argument described in the documentation for
`CrossEntropyLoss` do what you need?

Best.

K. Frank

Hi K. Frank,

The problem I am facing is that, to pass a custom weight map as an argument to `CrossEntropyLoss`, I have to pass the masks as an argument, as in the following:
`criterion = torch.nn.CrossEntropyLoss(make_weight_map(mask))`
which I guess is not the right way. I haven't approached this kind of problem before, so I'm pretty clueless right now.

Hi Rishav!

I'm not clear on what you are asking.

Do you want the contribution of a given prediction-target
pair to the loss to have a weight that depends (solely) on
the class of the target?

If so, I believe that the `weight` argument of `CrossEntropyLoss`
does what you want.
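For the case where the weight depends only on the target class, a minimal sketch of the `weight` argument looks like this. It assumes four classes, as in this thread; the weight values and random tensors are made up for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative per-class weights for 4 classes (black, red, green, blue).
# These values are made up; in practice they are often derived from
# inverse class frequencies in the training masks.
class_weights = torch.tensor([0.2, 1.0, 2.0, 2.0])

# Every pixel whose target is class c contributes with weight class_weights[c].
criterion = nn.CrossEntropyLoss(weight=class_weights)

logits = torch.randn(2, 4, 8, 8)          # (N, C, H, W) raw network output
target = torch.randint(0, 4, (2, 8, 8))   # (N, H, W) class indices, dtype torch.long
loss = criterion(logits, target)          # scalar: weighted mean over all pixels
```

With the default `reduction='mean'`, PyTorch divides by the sum of the weights of the target classes rather than by the number of pixels. Because the weight is looked up from the target class alone, it cannot express the distance-based border term of the U-Net weight map; that needs a per-pixel weight.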

If this isn't the case, could you write out a simple equation that
shows what you want the loss to be and how your "weights"
enter into it?

You make it sound like the `weight` argument doesn't do what
you want, but you don't say why. If you could explain how
`weight` doesn't work, that might help make clear what it is
you want to achieve.

Best.

K. Frank

Hi K. Frank,
I want to apply the following operation:
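For reference, the U-Net paper (Ronneberger et al., 2015) defines a pixel-wise weighted cross-entropy and its weight map as follows; this is presumably the operation referred to here. The loss is written with a leading minus sign so that it is a quantity to minimize:

$$
p_k(\mathbf{x}) = \frac{\exp\big(a_k(\mathbf{x})\big)}{\sum_{k'=1}^{K} \exp\big(a_{k'}(\mathbf{x})\big)},
\qquad
E = -\sum_{\mathbf{x} \in \Omega} w(\mathbf{x}) \log p_{\ell(\mathbf{x})}(\mathbf{x})
$$

$$
w(\mathbf{x}) = w_c(\mathbf{x}) + w_0 \exp\!\left(-\frac{\big(d_1(\mathbf{x}) + d_2(\mathbf{x})\big)^2}{2\sigma^2}\right)
$$

Here $a_k(\mathbf{x})$ is the network output (logit) for class $k$ at pixel $\mathbf{x}$, $\ell(\mathbf{x})$ is the true class of that pixel, $w_c$ is the class-balancing weight map, and $d_1$, $d_2$ are the distances to the border of the nearest and second-nearest object. The paper uses $w_0 = 10$ and $\sigma \approx 5$ pixels, the same values as `w0` and `sigma` in the code in the question. In PyTorch terms, $-\log p_{\ell(\mathbf{x})}(\mathbf{x})$ is the per-pixel value returned by `CrossEntropyLoss(reduction='none')`, so $E$ is that per-pixel loss multiplied by $w(\mathbf{x})$ and summed.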

As I understand it, if I set the custom weight this way, `criterion = torch.nn.CrossEntropyLoss(make_weight_map(mask))`, I'll be using the masks outside of the data loaders, and I'm really not sure whether that will work.

You could do the following:

```python
import torch

criterion = torch.nn.CrossEntropyLoss(reduction='none')
```

This makes the criterion return a loss value for each element (pixel) instead of a single reduced value. You can then multiply each loss element by its weight.

```python
# gt: ground truth class indices, shape (N, H, W), dtype torch.long
# pd: network output (raw logits), shape (N, C, H, W)
# W:  per-element weighting based on the distance map from UNet, shape (N, H, W)
loss = criterion(pd, gt)                      # per-pixel loss, shape (N, H, W)
loss = W * loss                               # ensure that weights are scaled appropriately
loss = loss.flatten(start_dim=1).sum(dim=1)   # sum the loss per image, shape (N,)
loss = loss.mean()                            # average across the batch
```

I've observed that a weight of 0 can cause complications. It might be a good idea to add a small `eps` to `W`.
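The steps above, including the `eps` suggestion, can be wrapped in a small function. This is a sketch with a hypothetical name (`weighted_pixel_ce`) and random tensors standing in for real network output, targets, and U-Net weights; the `eps` default of `1e-6` is an arbitrary choice.

```python
import torch
import torch.nn.functional as F


def weighted_pixel_ce(logits, target, weight_map, eps=1e-6):
    """Per-pixel weighted cross-entropy, summed per image, averaged over the batch.

    logits:     (N, C, H, W) raw network output
    target:     (N, H, W) class indices, dtype torch.long
    weight_map: (N, H, W) per-pixel weights (e.g. a U-Net weight map)
    eps:        small constant added so that no pixel gets a weight of exactly 0
    """
    per_pixel = F.cross_entropy(logits, target, reduction="none")  # (N, H, W)
    weighted = (weight_map + eps) * per_pixel
    per_image = weighted.flatten(start_dim=1).sum(dim=1)           # (N,)
    return per_image.mean()                                        # scalar


torch.manual_seed(0)
logits = torch.randn(2, 4, 8, 8, requires_grad=True)
target = torch.randint(0, 4, (2, 8, 8))
weight_map = torch.rand(2, 8, 8) * 10

loss = weighted_pixel_ce(logits, target, weight_map)
loss.backward()
```

With all weights equal to 1 and `eps=0`, this reduces to the ordinary summed cross-entropy divided by the batch size, which is a quick sanity check for the weighting.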


But I am converting the masks to class indices when feeding them into the data loader.

Do you mean there is a different mask for each class? If that's the case, then you'll need to write your own cross-entropy loss function.
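A hand-written cross-entropy for targets given as one binary mask per class could look like the sketch below. The function name `weighted_ce_from_masks` is hypothetical, and the per-class masks are built from random class indices only so that the result can be compared with the built-in loss.

```python
import torch
import torch.nn.functional as F


def weighted_ce_from_masks(logits, masks, weight_map):
    """Weighted cross-entropy when the target is one binary mask per class.

    logits:     (N, C, H, W) raw network output
    masks:      (N, C, H, W) float one-hot masks; masks[n, c] marks pixels of class c
    weight_map: (N, H, W) per-pixel weights
    """
    log_probs = F.log_softmax(logits, dim=1)       # log p_c(x) for every class
    per_pixel = -(masks * log_probs).sum(dim=1)    # -log p of the true class, (N, H, W)
    weighted = weight_map * per_pixel
    return weighted.flatten(start_dim=1).sum(dim=1).mean()  # sum per image, mean over batch


torch.manual_seed(0)
logits = torch.randn(2, 4, 8, 8)
target = torch.randint(0, 4, (2, 8, 8))
# Build per-class masks from class indices, just for this example.
masks = F.one_hot(target, num_classes=4).permute(0, 3, 1, 2).float()
weight_map = torch.rand(2, 8, 8)
loss = weighted_ce_from_masks(logits, masks, weight_map)
```

If the masks are strictly one-hot, `masks.argmax(dim=1)` turns them back into class indices and the built-in `CrossEntropyLoss` can be used instead; PyTorch 1.10 and later also accept class probabilities of shape `(N, C, H, W)` directly as the target.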

No, each mask has at most 4 classes (red, blue, green and black). I did what you suggested, apart from the last statement `loss = torch.mean(loss)`, because I compute a running average of the loss across batches with `train_loss = train_loss + ((1 / (batch_idx + 1)) * (loss.data - train_loss))`. I was getting an error because my `make_weight_map(masks)` expected a NumPy array instead of a tensor, and I had also forgotten to move the tensor to the CPU; both are fixed now. But now my system seems to have stalled.
Edit: I now get a memory error, even though I have 32 GB of RAM.
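One way to avoid converting tensors back to NumPy inside the training loop is to compute the weight map in the `Dataset` and return it alongside the image and target, so the weight function only ever sees NumPy arrays on the CPU. The sketch below uses hypothetical names (`SegmentationWithWeights`, `weight_fn`, `uniform_weights`) and random data; to plug in `make_weight_map`, `weight_fn` would need to turn the class-index target into the `(n_masks, H, W)` binary stack that function expects.

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class SegmentationWithWeights(Dataset):
    """Hypothetical dataset returning (image, target, weight_map) per sample.

    images:    list of float arrays, shape (C_in, H, W)
    targets:   list of int arrays, shape (H, W), values in [0, num_classes)
    weight_fn: function mapping a NumPy target to an (H, W) weight array
    """

    def __init__(self, images, targets, weight_fn):
        self.images = images
        self.targets = targets
        self.weight_fn = weight_fn

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        target = self.targets[idx]
        weight = self.weight_fn(target)  # computed in NumPy, on the CPU
        return (
            torch.as_tensor(self.images[idx], dtype=torch.float32),
            torch.as_tensor(target, dtype=torch.long),
            torch.as_tensor(weight, dtype=torch.float32),
        )


def uniform_weights(target):
    # Placeholder weight function for illustration: every pixel gets weight 1.
    return np.ones(target.shape, dtype=np.float32)


rng = np.random.default_rng(0)
images = [rng.random((3, 16, 16), dtype=np.float32) for _ in range(4)]
targets = [rng.integers(0, 4, size=(16, 16)) for _ in range(4)]

loader = DataLoader(SegmentationWithWeights(images, targets, uniform_weights), batch_size=2)
image, target, weight = next(iter(loader))
```

Because the weight map depends only on the mask, it could also be computed once per training image and cached to disk, which avoids recomputing it every epoch.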

Okay, I increased my compute to 61 GB of RAM and a 16 GB Tesla V100 GPU, and still got the memory error. I think the approach is wrong.
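A likely source of the memory errors is the pairwise distance computation in the weight-map code from the question: `xSum` and `ySum` each have shape (number of boundary pixels) × (H·W). As a rough illustration, assuming a 512×512 image with 5,000 boundary pixels, each of these arrays holds about 1.3 billion 8-byte values (roughly 10 GB), and `xSum + ySum` and its square root add more arrays of the same size. The sketch below swaps in a different technique, `scipy.ndimage.distance_transform_edt`, which computes the same per-pixel distance to the nearest boundary pixel using memory proportional to n_masks × H × W. The function name `unet_weight_map` is hypothetical, and the class term follows the same `w_1`/`w_0` idea as the question's code but is applied to the union of the masks (an assumption).

```python
import numpy as np
from scipy.ndimage import distance_transform_edt
from skimage.segmentation import find_boundaries


def unet_weight_map(masks, w0=10.0, sigma=5.0):
    """U-Net-style weight map built from Euclidean distance transforms.

    masks: (n_masks, H, W) array, one binary mask per object.
    Returns an (H, W) float array: class-balancing term + border term.
    """
    masks = np.asarray(masks) > 0
    masks = masks[masks.any(axis=(1, 2))]  # skip empty masks (they have no boundary)
    n_masks, nrows, ncols = masks.shape

    border = np.zeros((nrows, ncols))
    if n_masks > 0:
        # dist[i] = distance from every pixel to the inner boundary of mask i.
        # distance_transform_edt measures the distance to the nearest zero, so
        # the boundary pixels are passed in as the zeros.
        dist = np.stack([
            distance_transform_edt(~find_boundaries(m.astype(np.uint8), mode="inner"))
            for m in masks
        ])
        dist.sort(axis=0)  # per pixel: distances in increasing order
        # d1 alone for a single mask, d1 + d2 otherwise (as in the question's code)
        d = dist[0] if n_masks == 1 else dist[0] + dist[1]
        border = w0 * np.exp(-(d ** 2) / (2 * sigma ** 2))

    # Class-balancing term on the union of the masks (a simplification).
    foreground = masks.any(axis=0)
    w_1 = 1.0 - foreground.mean()  # weight for foreground pixels
    w_0 = 1.0 - w_1                # weight for background pixels
    class_weights = np.where(foreground, w_1, w_0)
    return class_weights + border


# Two squares separated by a 3-pixel gap.
masks = np.zeros((2, 32, 32), dtype=np.uint8)
masks[0, 4:14, 4:14] = 1
masks[1, 4:14, 17:27] = 1
w = unet_weight_map(masks)
```

Pixels in the narrow gap between the two squares receive a much larger weight than background pixels far from both objects, which is the behaviour the border term is meant to produce.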

Hi Rishav, it's been a while, so I guess you got this working. Let me know if you need help fixing this issue.