yangxi
(Xi Yang)
August 15, 2019, 11:00am
What I want is fairly simple: an MSE loss function that can mask out some items:
def masked_mse_loss(a, b, mask):
    # Accumulate squared errors only where mask == 1
    sum2 = 0.0
    num = 0
    for i in range(len(a)):
        if mask[i] == 1:
            sum2 += (a[i] - b[i]) ** 2.0
            num += 1
    # Mean over the unmasked elements
    return sum2 / num
Because of how backward works, I suspect such a straightforward implementation would not make PyTorch happy, but I have no idea how to do it the correct way.
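For reference, here is a minimal sketch of the same masked mean using boolean indexing instead of an explicit loop (a different technique from the loop above, chosen here only for illustration). It assumes `mask` holds 0/1 values and has the same shape as `a` and `b`; the name `masked_mse_loss` is reused from the question. Indexing with a boolean mask is differentiable, so autograd handles the backward pass.

```python
import torch

def masked_mse_loss(a, b, mask):
    # Keep only positions where mask is nonzero; boolean indexing is tracked by autograd
    selected = mask.bool()
    diff = a[selected] - b[selected]
    # Mean of the squared errors over the selected elements only
    return (diff ** 2).mean()

a = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
b = torch.tensor([1.0, 1.0, 1.0, 1.0])
mask = torch.tensor([1, 0, 0, 1])

loss = masked_mse_loss(a, b, mask)
loss.backward()
print(loss.item())     # 4.5
print(a.grad.tolist()) # [0.0, 0.0, 0.0, 3.0]
```

Note that if the mask is all zeros, `mean()` of an empty tensor returns `nan`, so that case may need special handling.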
Oli
(Olof Harrysson)
August 15, 2019, 8:40pm
I was under the impression that PyTorch would accept this kind of loss function. What error does it give you?
yangxi
(Xi Yang)
August 16, 2019, 9:20am
I wrote the loss class:
import torch

class MaskedMSELoss(torch.nn.Module):
    def __init__(self):
        super(MaskedMSELoss, self).__init__()

    def forward(self, input, target, mask):
        # Element-wise squared error on the flattened tensors
        diff2 = (torch.flatten(input) - torch.flatten(target)) ** 2.0
        sum2 = 0.0
        num = 0
        flat_mask = torch.flatten(mask)
        assert len(flat_mask) == len(diff2)
        # Sum the squared errors of unmasked elements one at a time
        for i in range(len(diff2)):
            if flat_mask[i] == 1:
                sum2 += diff2[i]
                num += 1
        return sum2 / num
What would the default backward look like?
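A quick way to see what the "default backward" does is to run the loop-based logic on a small tensor and inspect the gradients. The sketch below is an illustrative stand-in: `loop_masked_mse` is a hypothetical function name that mirrors the loop in `MaskedMSELoss.forward`, and the 2x2 inputs are made up. Autograd records every indexing and addition in the loop, so no hand-written `backward()` is needed; it is just slower than a vectorized version.

```python
import torch

def loop_masked_mse(input, target, mask):
    # Same loop logic as MaskedMSELoss.forward in the question
    diff2 = (torch.flatten(input) - torch.flatten(target)) ** 2.0
    flat_mask = torch.flatten(mask)
    sum2, num = 0.0, 0
    for i in range(len(diff2)):
        if flat_mask[i] == 1:
            sum2 = sum2 + diff2[i]  # each indexing/add op is recorded by autograd
            num += 1
    return sum2 / num

predict = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
target = torch.ones(2, 2)
mask = torch.tensor([[1, 0], [0, 1]])

loss = loop_masked_mse(predict, target, mask)
print(loss.grad_fn is not None)  # True: autograd built the backward graph
loss.backward()                  # no custom backward() required
print(predict.grad.tolist())     # [[0.0, 0.0], [0.0, 3.0]]; zero wherever mask == 0
```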
zhl515
August 17, 2019, 10:43am
You can call out.backward() to backpropagate the error:
import torch

predict = torch.tensor([1.0, 2, 3, 4], dtype=torch.float64, requires_grad=True)
target = torch.tensor([1.0, 1, 1, 1], dtype=torch.float64, requires_grad=True)
mask = torch.tensor([1, 0, 0, 1], dtype=torch.float64, requires_grad=True)
# Zero out masked positions, then average over the number of unmasked ones
out = torch.sum(((predict-target)*mask)**2.0) / torch.sum(mask)
out.backward()
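To check that the gradient autograd produces for this expression is what you would expect, one can compare it against the closed form. For out = Σ((p_i - t_i)·m_i)² / Σm_i, the derivative is ∂out/∂p_i = 2·m_i²·(p_i - t_i) / Σm_i, which is zero wherever m_i = 0. The sketch below reuses the numbers above but, as an assumption for illustration, drops `requires_grad` on `target` and `mask` since only `predict` needs a gradient here; it also runs `torch.autograd.gradcheck`, which expects float64 inputs like these.

```python
import torch

predict = torch.tensor([1.0, 2, 3, 4], dtype=torch.float64, requires_grad=True)
target = torch.tensor([1.0, 1, 1, 1], dtype=torch.float64)
mask = torch.tensor([1, 0, 0, 1], dtype=torch.float64)

out = torch.sum(((predict - target) * mask) ** 2.0) / torch.sum(mask)
out.backward()

# Closed-form gradient: 2 * mask_i**2 * (predict_i - target_i) / sum(mask)
expected = 2 * mask ** 2 * (predict.detach() - target) / mask.sum()
print(out.item())                               # 4.5
print(torch.allclose(predict.grad, expected))   # True

# Numerical check of the analytic gradient (finite differences vs. autograd)
f = lambda p: torch.sum(((p - target) * mask) ** 2.0) / torch.sum(mask)
print(torch.autograd.gradcheck(f, (predict,)))  # True
```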
yangxi
(Xi Yang)
August 17, 2019, 12:11pm
I rewrote it in a simpler form without the for loop and the if, and it works properly:
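import torch

class MaskedMSELoss(torch.nn.Module):
    def forward(self, input, target, mask):
        # Multiply by the 0/1 mask to zero out masked errors, then divide by the unmasked count
        diff2 = (torch.flatten(input) - torch.flatten(target)) ** 2.0 * torch.flatten(mask)
        result = torch.sum(diff2) / torch.sum(mask)
        return result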
Thanks!
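One edge case worth noting with the vectorized version: if every element is masked out, `torch.sum(mask)` is 0 and the result is `nan`. Below is a minimal sketch of one possible guard, assuming a 0/1 mask; `masked_mse_loss` is a hypothetical function name, and clamping the denominator to at least 1 is just one choice (it would distort the average if the mask held fractional weights below 1 in total).

```python
import torch

def masked_mse_loss(input, target, mask):
    # Same computation as the vectorized version, with the mask cast to the input dtype
    mask = torch.flatten(mask).to(input.dtype)
    diff2 = (torch.flatten(input) - torch.flatten(target)) ** 2.0 * mask
    # Clamp keeps the denominator >= 1, so an all-zero mask yields 0 instead of nan
    return torch.sum(diff2) / torch.clamp(torch.sum(mask), min=1.0)

x = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
y = torch.ones(4)
print(masked_mse_loss(x, y, torch.tensor([1, 0, 0, 1])).item())  # 4.5
print(masked_mse_loss(x, y, torch.zeros(4)).item())              # 0.0
```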