target = torch.ones([10, 64], dtype=torch.float32) # 64 classes, batch size = 10
output = torch.full([10, 64], 1.5) # A prediction (logit)
pos_weight = torch.ones([64]) # All weights are equal to 1
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
criterion(output, target) # -log(sigmoid(1.5))
This seems like it could be one of two scenarios: either a single binary class with 64 examples per batch (and a batch size of 10), or 64 binary classes with one example each (and a batch size of 10).
How would torch.nn.BCEWithLogitsLoss know which scenario this is? If I pass a single value, pos_weight=torch.tensor(3), would it automatically assume it’s the former? And if I use pos_weight = torch.ones(3), will BCEWithLogitsLoss assume it’s the former from the length of the tensor?
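To illustrate what I mean, here is a quick sketch of my current understanding: pos_weight seems to broadcast against the last dimension of the target, so a length-64 vector would be read as one weight per class (the 64-class reading), while a scalar applies uniformly and the two readings coincide. (The weight value 3.0 is just an arbitrary example.)

```python
import torch

target = torch.ones([10, 64])            # all-positive targets, as above
output = torch.full([10, 64], 1.5)       # the same logit everywhere

# Scalar pos_weight: broadcasts over every element, so both readings coincide.
scalar = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(3.0))

# Length-64 pos_weight: broadcasts along the last dimension, i.e. it is
# interpreted as one weight per class (the 64-class reading).
per_class = torch.nn.BCEWithLogitsLoss(pos_weight=torch.ones(64) * 3.0)

loss_scalar = scalar(output, target)
loss_per_class = per_class(output, target)

# With all targets positive and identical logits, both reduce to
# -3 * log(sigmoid(1.5)), so they should print the same number.
print(loss_scalar.item())
print(loss_per_class.item())
```

In this degenerate example the two losses are equal either way, so it does not distinguish the scenarios; the question is what happens when the pos_weight shape is ambiguous.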