The docs state:
- pos_weight (Tensor, optional) – a weight of positive examples. Must be a vector with length equal to the number of classes.
https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html
That suggests you don't need to reduce the counts to a scalar with torch.sum; instead, sum only over the batch (sample) dimension.
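To make the docs' description concrete, here is a minimal sketch for a multi-label classifier. The 3 classes and the random logits/targets are made-up assumptions. The counts are reduced only over the batch dimension, so pos_weight keeps one entry per class. The loss is then checked against the formula from the BCEWithLogitsLoss docs, where pos_weight scales only the positive term.

```python
import torch
import torch.nn as nn

# Hypothetical multi-label data: batch of 8 samples, 3 classes
torch.manual_seed(0)
logits = torch.randn(8, 3)
targets = torch.randint(0, 2, (8, 3)).float()

# One pos_weight entry per class, as the docs describe
num_pos = targets.sum(dim=0)                 # positives per class, summed over the batch only
num_neg = targets.shape[0] - num_pos         # negatives per class
pos_weight = num_neg / num_pos.clamp(min=1)  # clamp avoids division by zero

loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
loss = loss_fn(logits, targets)

# Manual equivalent (default 'mean' reduction): only the positive term is scaled
sig = torch.sigmoid(logits)
manual = -(pos_weight * targets * torch.log(sig)
           + (1 - targets) * torch.log(1 - sig)).mean()
print(loss.item(), manual.item())
```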
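import torch

# Calculate pos_weight
# Per sample (e.g. in your Dataset's __getitem__): boolean masks of negative and positive pixels
negatives = img_y == 0
positives = img_y == 1
...
# Running per-pixel counts; the shape must match the mask shape of a single image
negs = torch.zeros((1, 256, 256))
posv = torch.zeros((1, 256, 256))
for i in range(len(train_y)):
    image, mask, negatives, positives = dataset[i]
    negs += negatives  # if these carry a batch dim, add negatives.sum(dim=0) instead
    posv += positives  # same note as above
# clamp avoids division by zero (inf) for pixels that are never positive
pos_weight = negs / posv.clamp(min=1)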
See the comments, too.
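If your masks come out of a DataLoader in batches, the same per-pixel counts can be accumulated by summing each batch over dim 0, and the result can go straight into the loss, since a (1, H, W) pos_weight broadcasts against (B, 1, H, W) targets. This is a minimal sketch with made-up data: the 64x64 size, the ~10% positive rate, batch size 4, and the TensorDataset wrapper are all assumptions, not your actual setup.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical data: 20 single-channel 64x64 images with ~10% positive mask pixels
torch.manual_seed(0)
images = torch.randn(20, 1, 64, 64)
masks = (torch.rand(20, 1, 64, 64) < 0.1).float()
loader = DataLoader(TensorDataset(images, masks), batch_size=4)

# Per-pixel counts for one image; each batch is summed over dim=0 (the batch dim)
negs = torch.zeros(1, 64, 64)
posv = torch.zeros(1, 64, 64)
for _, mask in loader:
    posv += (mask == 1).sum(dim=0)
    negs += (mask == 0).sum(dim=0)

# Clamp so pixels that are never positive don't produce inf
pos_weight = negs / posv.clamp(min=1)

# pos_weight of shape (1, 64, 64) broadcasts against targets of shape (B, 1, 64, 64)
loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
logits = torch.randn(4, 1, 64, 64)
loss = loss_fn(logits, masks[:4])
print(pos_weight.shape, loss.item())
```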