I am solving a multi-class segmentation problem using the U-Net architecture. As specified in the U-Net paper, I am trying to implement custom weight maps to counter class imbalance.
Below is the code for custom weight map-

import numpy as np
from skimage.segmentation import find_boundaries

w0 = 10
sigma = 5

def make_weight_map(masks):
    """
    Generate the weight maps as specified in the U-Net paper
    for a set of binary masks.

    Parameters
    ----------
    masks: array-like
        A 3D array of shape (n_masks, image_height, image_width),
        where each slice along the 0th axis is one binary mask.

    Returns
    -------
    array-like
        A 2D array of shape (image_height, image_width)
    """
    nrows, ncols = masks.shape[1:]
    masks = (masks > 0).astype(int)
    distMap = np.zeros((nrows * ncols, masks.shape[0]))
    X1, Y1 = np.meshgrid(np.arange(nrows), np.arange(ncols))
    X1, Y1 = np.c_[X1.ravel(), Y1.ravel()].T
    for i, mask in enumerate(masks):
        # find the boundary of each mask and compute the
        # distance of every pixel from this boundary
        bounds = find_boundaries(mask, mode='inner')
        X2, Y2 = np.nonzero(bounds)
        xSum = (X2.reshape(-1, 1) - X1.reshape(1, -1)) ** 2
        ySum = (Y2.reshape(-1, 1) - Y1.reshape(1, -1)) ** 2
        distMap[:, i] = np.sqrt(xSum + ySum).min(axis=0)
    ix = np.arange(distMap.shape[0])
    if distMap.shape[1] == 1:
        # only one mask: a single distance term in the exponent
        d1 = distMap.ravel()
        border_loss_map = w0 * np.exp((-1 * (d1) ** 2) / (2 * (sigma ** 2)))
    else:
        # distances to the two nearest masks (order does not matter,
        # since only d1 + d2 enters the formula)
        if distMap.shape[1] == 2:
            d1_ix, d2_ix = np.argpartition(distMap, 1, axis=1)[:, :2].T
        else:
            d1_ix, d2_ix = np.argpartition(distMap, 2, axis=1)[:, :2].T
        d1 = distMap[ix, d1_ix]
        d2 = distMap[ix, d2_ix]
        border_loss_map = w0 * np.exp((-1 * (d1 + d2) ** 2) / (2 * (sigma ** 2)))
    xBLoss = np.zeros((nrows, ncols))
    xBLoss[X1, Y1] = border_loss_map
    # class-balancing weight map w_c
    loss = np.zeros((nrows, ncols))
    w_1 = 1 - masks.sum() / loss.size
    w_0 = 1 - w_1
    loss[masks.sum(0) == 1] = w_1
    loss[masks.sum(0) == 0] = w_0
    ZZ = xBLoss + loss
    return ZZ
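To make the core distance step concrete, here is a minimal, self-contained sketch of the pairwise nearest-boundary-distance computation used inside the loop. The two boundary pixels are made up for illustration (in the real function they come from find_boundaries):

```python
import numpy as np

nrows, ncols = 4, 4
X1, Y1 = np.meshgrid(np.arange(nrows), np.arange(ncols))
X1, Y1 = np.c_[X1.ravel(), Y1.ravel()].T  # all pixel coordinates, flattened

# Pretend the boundary consists of exactly two pixels: (0, 0) and (3, 3)
X2 = np.array([0, 3])
Y2 = np.array([0, 3])

# Squared coordinate differences between every boundary pixel and every pixel
xSq = (X2.reshape(-1, 1) - X1.reshape(1, -1)) ** 2
ySq = (Y2.reshape(-1, 1) - Y1.reshape(1, -1)) ** 2

# For each pixel, keep the distance to the *nearest* boundary pixel
dist = np.sqrt(xSq + ySq).min(axis=0)
```

Each column of the (n_boundary, n_pixels) distance matrix corresponds to one pixel, and the `min` over axis 0 picks its nearest boundary point, exactly as in the loop body above.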

I am really not sure how to incorporate custom weights into the cross-entropy loss. I came across this example for Keras but couldn't find any equivalent for PyTorch.

The problem I am facing is how to pass the custom weights as an argument to CrossEntropyLoss. I would have to pass the masks like this: criterion = torch.nn.CrossEntropyLoss(make_weight_map(mask)),
which I guess is not the right way. I haven't approached this kind of problem before, so I'm pretty clueless right now.

Do you want the contribution of a given prediction-target
pair to the loss to have a weight that depends (solely) on
the class of the target?

If so, I believe that the weight argument of CrossEntropyLoss
does what you want.
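For reference, the weight argument takes one scalar per class, not per pixel. A minimal sketch, with illustrative class weights and shapes (four classes, as mentioned later in the thread):

```python
import torch

# One scalar weight per class (values here are made up for illustration)
class_weights = torch.tensor([1.0, 2.0, 2.0, 0.5])
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

logits = torch.randn(8, 4, 16, 16)         # (batch, n_classes, H, W)
target = torch.randint(0, 4, (8, 16, 16))  # class index per pixel, torch.long

loss = criterion(logits, target)  # scalar: weighted mean over all pixels
```

Every pixel whose target is class c contributes with weight class_weights[c]; the weighting depends only on the target's class, not on the pixel's position.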

If this isn't the case, could you write out a simple equation that shows what you want the loss to be and how your "weights" enter into it?

You make it sound like the weight argument doesn't do what you want, but you don't say why. If you could explain how weight doesn't work, that might help make clear what it is you want to achieve.

According to my understanding, if I set the custom weights via criterion = torch.nn.CrossEntropyLoss(make_weight_map(mask)), I'll be using the masks outside of the data loaders, which I am really not sure will work.
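One common way to avoid using the masks outside the data loader is to compute the weight map per sample inside the Dataset, so the DataLoader batches it together with the image and target. This is only a sketch under assumed shapes; SegDataset and weight_fn are hypothetical names:

```python
import numpy as np
import torch
from torch.utils.data import Dataset

class SegDataset(Dataset):
    """Hypothetical dataset returning (image, target, weight map) per sample."""

    def __init__(self, images, masks, weight_fn):
        # images: (N, C, H, W), masks: (N, n_classes, H, W), both NumPy arrays
        self.images = images
        self.masks = masks
        self.weight_fn = weight_fn  # e.g. make_weight_map

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        mask = self.masks[idx]        # still a NumPy array at this point
        w = self.weight_fn(mask)      # weight map computed on the CPU, per sample
        return (torch.from_numpy(self.images[idx]).float(),
                torch.from_numpy(mask.argmax(0)).long(),  # class-index target
                torch.from_numpy(w).float())
```

Each weight map is then computed once per sample on the CPU, and the batched weight tensor arrives in the training loop with the same leading batch dimension as the per-pixel loss.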

Construct the criterion with reduction='none', i.e. criterion = torch.nn.CrossEntropyLoss(reduction='none'). This ensures that the function will return a loss value for each element. You could then multiply the weights with each loss element.

gt  # ground truth, torch.long, shape (batch, H, W)
pd  # network output, shape (batch, n_classes, H, W)
W   # per-pixel weighting based on the distance map from U-Net
loss = criterion(pd, gt)                     # per-pixel loss, shape (batch, H, W)
loss = W * loss                              # ensure that weights are scaled appropriately
loss = loss.flatten(start_dim=1).sum(dim=1)  # sums the loss per image
loss = loss.mean()                           # average across the batch

I've observed that a weight of 0 can cause complications. It might be a good idea to add a small eps factor to W.
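Putting the pieces above together, a runnable sketch of the per-pixel weighted loss (the shapes and random tensors are illustrative only):

```python
import torch

criterion = torch.nn.CrossEntropyLoss(reduction='none')  # per-pixel loss

pd = torch.randn(2, 4, 8, 8, requires_grad=True)  # network output (batch, classes, H, W)
gt = torch.randint(0, 4, (2, 8, 8))               # ground truth, torch.long
W = torch.rand(2, 8, 8) + 1e-6                    # weight map; small eps guards against zeros

loss = criterion(pd, gt)                     # shape (2, 8, 8)
loss = W * loss                              # apply per-pixel weighting
loss = loss.flatten(start_dim=1).sum(dim=1)  # sum per image
loss = loss.mean()                           # average across the batch
loss.backward()                              # gradients flow through the weighting
```

Because W only multiplies the per-pixel losses, it can be any tensor broadcastable to the loss shape; it does not need requires_grad.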

No, each mask has at most 4 classes (red, blue, green and black). I did what you suggested, apart from the last statement loss = torch.mean(loss), because I am calculating the average across the batch using train_loss = train_loss + ((1 / (batch_idx + 1)) * (loss.data - train_loss)). I was encountering an error because make_weight_map(masks) expects a NumPy array instead of a tensor, and I had also forgotten to move the tensor to the CPU, which I have done now. But now my system seems to have stalled.
Edit: I am now getting a memory error, even though my RAM size is 32 GB.