Mask selection with expand


(Pierre Antoine Ganaye) #1

Hi,
I have a target tensor that I want to one-hot encode and then filter with a binary mask that is itself derived from the target tensor.

# target tensor
t = torch.Tensor(4, 5, 5).random_(0, 10).long()
# binary mask
m = t == 3
# one hot encoded tensor
et = torch.Tensor(4, 10, 5, 5).zero_()
# one hot encoding
et.scatter_(1, t.unsqueeze(1), 1)
# expand the binary mask to match the new one-hot encoded target
m = m.unsqueeze(1).expand(et.size())

# should print a vector of ones, but does not
print(et[~m].view(10, -1).sum(0))

After filtering with the mask, I can’t find a correct view to get back to the encoded targets.
Thanks


(Clément Pinard) #2

You will run into problems with one of the inner dimensions, as you won’t always have the same number of elements.
Here’s a deliberately over-simplified example:

t = torch.Tensor(2, 2).random_(0, 2).long()
# e.g. t = [[0, 0],
#            [1, 0]]  (long tensor, one possible random draw)
m = t == 1
# m = [[0, 0],
#      [1, 0]]  (byte tensor this time)
et = torch.Tensor(2, 2, 2).zero_()
et.scatter_(1, t.unsqueeze(1), 1)
# et = [[[1, 1],
#        [0, 0]],
#       [[0, 1],
#        [1, 0]]] (float)
m = m.unsqueeze(1).expand(et.size())

# ~m = [[[1, 1],
#        [1, 1]],
#       [[0, 1],
#        [0, 1]]] (byte)

#et[~m] = [1,1,0,0,1,0]

As you can see, et still holds the information you need, but because everything is flattened, you lose the dimensional information. What you want is to get

et[~m] = [[[1,0],
           [1,0]],
          [[1,0]]]

And this is not a regular array, since the second dimension has size 2 for the first sample and 1 for the second.
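Counting the kept positions per sample makes this concrete. A minimal sketch, assuming a recent PyTorch where comparisons return bool tensors, with the random draw above written out by hand:

```python
import torch

# Same toy target as above, written out explicitly instead of drawn at random
t = torch.tensor([[0, 0],
                  [1, 0]])
m = t == 1  # bool mask, True where the label is 1

# Number of kept (unmasked) positions for each sample in the batch
kept_per_sample = (~m).sum(dim=1)
print(kept_per_sample)  # sample 0 keeps 2 positions, sample 1 keeps 1
```

Since the counts differ, no single `view` can turn the flat result back into a rectangular `[B, N, C]` tensor.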

The right way to do it is to put the one-hot dimension first:


# target tensor
t = torch.Tensor(4, 5, 5).random_(0, 10).long()
# binary mask
m = t == 3
# one hot encoded tensor
et = torch.Tensor(10, 4, 5, 5).zero_()
# one hot encoding
et.scatter_(0, t.unsqueeze(0), 1)
# expand the binary mask to match the new one-hot encoded target
m = m.unsqueeze(0).expand(et.size())

# prints a vector of ones !
print(et[~m].view(10, -1).sum(0))
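On recent PyTorch versions, a similar result can be obtained without moving the class dimension to the front, assuming `torch.nn.functional.one_hot` is available: it puts the class dimension last, so indexing with the original, un-expanded `[B, H, W]` mask keeps one full one-hot row per pixel. This is a sketch of an alternative, not the approach from the post above:

```python
import torch
import torch.nn.functional as F

num_classes = 10
t = torch.randint(0, num_classes, (4, 5, 5))  # target, shape [B, H, W]
m = t == 3                                     # mask, shape [B, H, W]

# one_hot appends the class dimension last: shape [B, H, W, C]
et = F.one_hot(t, num_classes)

# Indexing with a [B, H, W] bool mask keeps the trailing class dimension,
# so the result is already [N, C] with one row per kept pixel
kept = et[~m]
print(kept.shape)   # [N, 10], N = number of pixels whose label is not 3
print(kept.sum(1))  # all ones
```

If a class-first layout is needed afterwards, `kept.t()` gives `[10, N]`.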

If you still want batch-wise operation, you need to treat each sample separately in a loop, because the samples won’t necessarily keep the same number of pixels.
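A sketch of that loop, assuming a recent PyTorch with `torch.nn.functional.one_hot` (class dimension last); each sample yields its own `[N_b, C]` tensor, and the results are kept in a Python list because the `N_b` differ:

```python
import torch
import torch.nn.functional as F

num_classes = 10
t = torch.randint(0, num_classes, (4, 5, 5))  # [B, H, W]
m = t == 3                                     # [B, H, W]
et = F.one_hot(t, num_classes)                 # [B, H, W, C]

# One [N_b, C] tensor per sample; N_b generally differs between samples,
# which is why the result is a list rather than a single tensor
per_sample = [et[b][~m[b]] for b in range(t.size(0))]
for b, rows in enumerate(per_sample):
    print(b, rows.shape)
```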


(Pierre Antoine Ganaye) #3

Thanks for the detailed answer, @ClementPinard. Your answer clearly solves the problem I posted, but it doesn’t fit well with my original goal, which is implementing a dice loss. I need to filter out invalid labels, which is why I apply a mask to the one-hot encoded labels, but I also need to filter the output tensor, and I don’t think I can do the same array rearrangement and keep autograd working correctly. Here is the original code:

def dice_loss3(output, target, weights=1, ignore_index=None):
    output = output.exp()
    encoded_target = output.data.clone().zero_()
    if ignore_index is not None:
        # mask of invalid label
        mask = target == ignore_index
        # clone target to not affect the variable ?
        filtered_target = target.clone()
        # replace invalid label with whatever legal index value
        filtered_target[mask] = 0
        # one hot encoding
        encoded_target.scatter_(1, filtered_target.unsqueeze(1), 1)
        # expand the mask for the encoded target array
        mask = mask.unsqueeze(1).expand(output.data.size())
    else:
        encoded_target.scatter_(1, target.unsqueeze(1), 1)
    encoded_target = Variable(encoded_target)

    assert output.size() == encoded_target.size(), "Input sizes must be equal."
    assert output.dim() == 4, "Input must be a 4D Tensor."

    o = output[~mask].view(-1, output.size(1))
    t = encoded_target[~mask].view(-1, output.size(1))
    numerator = o * t
    denominator = o.pow(2) + t

    dice = 2 * (numerator / denominator) * weights
    return dice.sum() 

I tried to adapt it:

def dice_loss3(output, target, weights=1, ignore_index=None):
    b, c, h, w = output.size()
    output = output.exp()
    output = output.permute(1, 0, 2, 3)
    encoded_target = output.data.new(c, b, h, w).zero_()
    if ignore_index is not None:
        # mask of invalid label
        mask = target == ignore_index
        # clone target to not affect the variable ?
        filtered_target = target.clone()
        # replace invalid label with whatever legal index value
        filtered_target[mask] = 0
        # one hot encoding
        encoded_target.scatter_(0, filtered_target.unsqueeze(0), 1)
        # expand the mask for the encoded target array
        mask = mask.unsqueeze(0).expand(encoded_target.size())
    else:
        encoded_target.scatter_(0, target.unsqueeze(0), 1)  # class dim is 0 after the permute
    encoded_target = Variable(encoded_target)

    assert output.size() == encoded_target.size(), "Input sizes must be equal."
    assert output.dim() == 4, "Input must be a 4D Tensor."

    o = output[~mask].view(output.size(0), -1)
    t = encoded_target[~mask].view(output.size(0), -1)
    numerator = o * t
    denominator = o.pow(2) + t

    dice = 2 * (numerator.sum(1) / denominator.sum(1)) * weights
    return dice.sum()

This is clearly wrong: I get dice values over 5000, whereas it shouldn’t be higher than 1. The formula is simple: (output*one_hot_targets) / (output.pow(2) + one_hot_targets)

If you have any idea what could be wrong, I’d really appreciate it; I’ve tried so many hacks to solve this… Thanks


(Pierre Antoine Ganaye) #4

Actually, I think even the formula is broken: simply reassigning my ignore_label to another label should reduce the overall dice value, but when I tested it I still got dice > 1.


(Pierre Antoine Ganaye) #5

However, this is still a similarity measure, and I don’t see how to maximise it jointly with the cross-entropy loss. If it were the real dice, I would have done this:
ce(output, target) + (1 - dice(output, target))
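A sketch of that combination, assuming a recent PyTorch (no `Variable`), a network that outputs log-probabilities as in this thread, and the standard soft dice `2 * sum(p * t) / (sum(p) + sum(t))`, which stays between 0 and 1. `soft_dice` and `ce_plus_dice` are hypothetical helper names, and `ignore_index` handling is left out:

```python
import torch
import torch.nn.functional as F


def soft_dice(probs, one_hot, eps=1e-6):
    # Per-class soft dice, summed over batch and pixels: 2 * sum(p * t) / (sum(p) + sum(t))
    dims = (0, 2, 3)
    intersection = (probs * one_hot).sum(dims)
    total = probs.sum(dims) + one_hot.sum(dims)
    return (2 * intersection + eps) / (total + eps)  # shape [C], each value in [0, 1]


def ce_plus_dice(log_probs, target):
    # log_probs: [B, C, H, W] log-softmax output; target: [B, H, W] class indices
    probs = log_probs.exp()
    one_hot = F.one_hot(target, log_probs.size(1)).permute(0, 3, 1, 2).float()
    ce = F.nll_loss(log_probs, target)
    # 1 - dice is 0 for a perfect overlap, so both terms are minimised together
    return ce + (1 - soft_dice(probs, one_hot).mean())


logits = torch.randn(2, 3, 4, 4, requires_grad=True)
target = torch.randint(0, 3, (2, 4, 4))
loss = ce_plus_dice(F.log_softmax(logits, dim=1), target)
loss.backward()
print(loss.item())
```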


(Clément Pinard) #6

What about giving the dice terms a weight of zero wherever the label has to be ignored?


def dice_loss(output, target, weights=1, ignore_index=None):
    b, c, h, w = output.size()
    output = output.exp()
    encoded_target = output.data.new(b, c, h, w).zero_()
    encoded_target.scatter_(1, target.unsqueeze(1), 1)
    encoded_target = Variable(encoded_target)
    mask = None
    if ignore_index is not None:
        # mask of valid labels: 1 where the pixel counts, 0 where it is ignored
        mask = (target != ignore_index).float()
        # add a size-1 class dim; broadcasting spreads it over the C channels
        mask = Variable(mask.unsqueeze(1))


    assert output.size() == encoded_target.size(), "Input sizes must be equal."
    assert output.dim() == 4, "Input must be a 4D Tensor."

    o = output
    t = encoded_target
    numerator = o * t
    denominator = o.pow(2) + t

    dice = 2 * (numerator / denominator) * weights  # dimensions are still [B, C, H, W] at this point
    if mask is not None:
        return (dice*mask).sum()  # array broadcasting is awesome !
    return dice.sum() 

(Pierre Antoine Ganaye) #7

I had hoped so, but it’s not possible for two reasons:

  • the ignore_index value is negative, and scatter won’t accept a negative index. Remapping it to a new label would mean reshaping the output tensor too.
  • when computing numerator = output * target, the ignore_index labels and their outputs are still taken into account.
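For the first point, a common workaround (the same clone, set-to-0, scatter, zero-out idea used in the `dice_loss` code later in this thread) is to send ignored pixels to any legal class before `scatter_` and then wipe their one-hot vectors. A minimal sketch, assuming a recent PyTorch, `ignore_index = -1`, and a toy 1D target per sample:

```python
import torch

ignore_index = -1  # assumed value, negative as in this thread
num_classes = 3
target = torch.tensor([[0, 2, -1],
                       [1, -1, 2]])  # [B, W] toy target; -1 marks pixels to ignore

mask = target == ignore_index
# scatter_ rejects negative indices, so map ignored pixels to any legal class first
safe_target = target.masked_fill(mask, 0)

one_hot = torch.zeros(target.size(0), num_classes, target.size(1))
one_hot.scatter_(1, safe_target.unsqueeze(1), 1)

# Then wipe the one-hot vector of every ignored pixel; the mask broadcasts over the class dim
one_hot = one_hot * (~mask).unsqueeze(1).to(one_hot.dtype)
print(one_hot.sum(1))  # [[1, 1, 0], [1, 0, 1]]: ignored pixels have an all-zero vector
```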

(Clément Pinard) #8

How can ignore_index be negative? I may not have grasped what it is about, but shouldn’t it be an integer between 0 and C-1?

numerator = output * target is an element-wise operation, so it has the same size as the mask. That means that if you weight some of its terms with 0 before summing, you’ll be fine, because the gradient with respect to those terms will be zero. The only downside is that everything is computed before the ignored indices are thrown out, but (at least to me) it’s much simpler to read; whether the extra work matters depends on the ratio of samples you throw out.
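A tiny autograd check of that claim, assuming a recent PyTorch: every weighted term is still computed, but the one multiplied by 0 receives exactly zero gradient. The tensors are made up for illustration:

```python
import torch

x = torch.tensor([0.2, 0.5, 0.9], requires_grad=True)
weight = torch.tensor([1.0, 0.0, 1.0])  # 0 marks the term to ignore

loss = (x * x * weight).sum()  # element-wise terms, weighted before the sum
loss.backward()
print(x.grad)  # gradient is 2 * x * weight, i.e. about [0.4, 0.0, 1.8]
```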


(Pierre Antoine Ganaye) #9

You are right, weighting with 0 is also a good option; that’s what I did first, but I thought using a mask would be simpler and faster.

def dice_loss(output, target, weights=1, ignore_index=None):
    output = output.exp()
    encoded_target = output.data.clone().zero_()
    if ignore_index is not None:
        # mask of invalid label
        mask = target == ignore_index
        # clone target to not affect the variable ?
        filtered_target = target.clone()
        # replace invalid label with whatever legal index value
        filtered_target[mask] = 0
        # one hot encoding
        encoded_target.scatter_(1, filtered_target.unsqueeze(1), 1)
        # expand the mask for the encoded target array
        mask = mask.unsqueeze(1).expand(output.data.size())
        # apply 0 to masked pixels
        encoded_target[mask] = 0
    else:
        encoded_target.scatter_(1, target.unsqueeze(1), 1)
    encoded_target = Variable(encoded_target)

    assert output.size() == encoded_target.size(), "Input sizes must be equal."
    assert output.dim() == 4, "Input must be a 4D Tensor."

    numerator = (output * encoded_target).sum(dim=3).sum(dim=2)
    denominator = output.pow(2) + encoded_target
    if ignore_index is not None:
        # exclude masked values from den1
        denominator[mask] = 0

    dice = 2 * (numerator / denominator.sum(dim=3).sum(dim=2)) * weights
    return dice.sum() / dice.size(0)

ignore_index is negative; it’s the same value I pass to nll_loss, and those labels are excluded from the loss.
This function tends to plus infinity as the similarity improves, so I can’t think of a way to optimize it: loss = -dice(output, target) would be negative from the start, and it would not converge since it would tend to minus infinity…


(Pierre Antoine Ganaye) #10

Thank you for your time, @ClementPinard, your help was really appreciated! I’m marking this topic as solved.


(Clément Pinard) #11

You’re welcome. I just came across your comment here, so I understand the problem better now; I didn’t know about the dice loss before.

Some remarks on the dice loss: at best, the maximum is reached when output = target, which gives
dice = 2 * (target**2).sum() / (target**2 + target).sum()
target is only composed of 0s and 1s, so target**2 equals target, which brings us to dice = 2 * target.sum() / (2 * target.sum()) = 1.

At worst, output * target is 0 everywhere and dice is 0.

Are you sure output has its values between 0 and 1? I see output.exp(); shouldn’t it be torch.nn.functional.softmax(output)? (Maybe your output has already been passed through log_softmax, though…)


(Pierre Antoine Ganaye) #12

Your logic is right; the dice for a binary problem is defined as
dice = 2 * (output * target).sum() / (output.sum() + target.sum())
where output * target is the intersection of the two sets.
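A small worked comparison, assuming a recent PyTorch and made-up 4-pixel masks: for 0/1 outputs `output**2 == output`, so this definition and the `output.pow(2) + target` denominator used earlier in the thread agree, while for soft outputs they differ. Both equal 1 when output equals target. `dice_sets` and `dice_squared` are hypothetical names:

```python
import torch


def dice_sets(output, target):
    # 2 * |A ∩ B| / (|A| + |B|), the binary definition above
    return 2 * (output * target).sum() / (output.sum() + target.sum())


def dice_squared(output, target):
    # Variant used earlier in the thread, with output**2 in the denominator
    return 2 * (output * target).sum() / (output.pow(2) + target).sum()


target = torch.tensor([1.0, 0.0, 1.0, 0.0])
hard = torch.tensor([1.0, 1.0, 0.0, 0.0])  # binary prediction, one pixel overlaps
soft = torch.tensor([0.5, 0.5, 0.5, 0.5])  # uncertain prediction

print(dice_sets(hard, target), dice_squared(hard, target))  # both 0.5, since o**2 == o for 0/1
print(dice_sets(soft, target), dice_squared(soft, target))  # 0.5 vs about 0.667
print(dice_squared(target, target))                         # 1 in the best case
```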

You are correct about the output values. I use output.exp() because the network output is already log-softmaxed (output = logsoftmax(x)), so output.exp() = softmax(x). I finally got something to work with ignore_label:
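A quick numerical check of that identity, assuming a recent PyTorch and the class dimension at index 1 as in this thread:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 4, 3, 3)          # [B, C, H, W] raw network scores
log_probs = F.log_softmax(logits, dim=1)  # what the network in this thread outputs

# exp() undoes the log, giving per-pixel class probabilities that sum to 1
probs = log_probs.exp()
print(torch.allclose(probs, F.softmax(logits, dim=1)))
print(torch.allclose(probs.sum(1), torch.ones(2, 3, 3)))
```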

def dice_loss(output, target, weights=None, ignore_index=None):
    smooth = 1.
    loss = 0.

    output = output.exp() # because output = logsoftmax(output)
    encoded_target = output.data.clone().zero_()
    if ignore_index is not None:
        mask = target == ignore_index
        target = target.clone()
        target[mask] = 0
        encoded_target.scatter_(1, target.unsqueeze(1), 1)
        mask = mask.unsqueeze(1).expand_as(encoded_target)
        encoded_target[mask] = 0
    else:
        encoded_target.scatter_(1, target.unsqueeze(1), 1)
    encoded_target = Variable(encoded_target)

    if weights is None:
        weights = Variable(torch.ones(output.size(1)).type_as(output.data))

    intersection = output * encoded_target
    numerator = 2 * intersection.sum(3).sum(2).sum(0) + smooth
    denominator = (output + encoded_target).sum(3).sum(2).sum(0) + smooth
    loss_per_channel = weights * (1 - (numerator / denominator))

    return loss_per_channel.sum() / output.size(1)

The code is slow, though.
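For reference, a sketch of the same loss on a recent PyTorch (an assumption: `Variable` is no longer needed and `torch.nn.functional.one_hot` exists). It keeps the structure of the function above (sums over batch and pixels per class, `smooth = 1`) but replaces masked assignment with multiplication by a validity mask; it is an illustrative rewrite, not the poster's code, and no speed claim is made for it:

```python
import torch
import torch.nn.functional as F


def dice_loss(log_probs, target, weights=None, ignore_index=None, smooth=1.0):
    # log_probs: [B, C, H, W] log-softmax output; target: [B, H, W] class indices
    probs = log_probs.exp()
    num_classes = probs.size(1)

    if ignore_index is not None:
        valid = target != ignore_index              # [B, H, W], True on pixels that count
        target = target.masked_fill(~valid, 0)      # any legal class, zeroed out below
    else:
        valid = torch.ones_like(target, dtype=torch.bool)

    one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).to(probs.dtype)
    valid = valid.unsqueeze(1).to(probs.dtype)      # [B, 1, H, W], broadcasts over C
    probs = probs * valid                           # out-of-place, so exp()'s saved output is intact
    one_hot = one_hot * valid

    dims = (0, 2, 3)                                # per class: sum over batch and pixels
    numerator = 2 * (probs * one_hot).sum(dims) + smooth
    denominator = (probs + one_hot).sum(dims) + smooth
    if weights is None:
        weights = torch.ones(num_classes, dtype=probs.dtype, device=probs.device)
    loss_per_channel = weights * (1 - numerator / denominator)
    return loss_per_channel.sum() / num_classes


# Usage: a confident, correct prediction on every valid pixel gives a loss close to 0
target = torch.tensor([[[0, 1], [2, -1]]])          # [1, 2, 2]; -1 marks an ignored pixel
logits = 20.0 * F.one_hot(target.clamp(min=0), 3).permute(0, 3, 1, 2).float()
logits.requires_grad_()
loss = dice_loss(F.log_softmax(logits, dim=1), target, ignore_index=-1)
loss.backward()
print(loss.item())
```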


(Clément Pinard) #13

Try this version:


def dice_loss(output, target, weights=None, ignore_index=None):
    smooth = 1.
    loss = 0.

    output = output.exp() # because output = logsoftmax(output)
    encoded_target = output.detach() * 0  # fresh zeros with output's type and device, outside the graph
    if ignore_index is not None:
        mask = target == ignore_index
        target = target.clone()
        target[mask] = 0
        encoded_target.scatter_(1, target.unsqueeze(1), 1)
        mask = Variable(mask.unsqueeze(1)).type_as(encoded_target)  # no need to expand, thanks to array broadcasting
        output = output * mask  # out-of-place: an in-place *= would overwrite exp()'s saved result
        encoded_target *= mask
    else:
        encoded_target.scatter_(1, target.unsqueeze(1), 1)

    if weights is None:
        weights = Variable(torch.ones(1, output.size(1)).type_as(output))
    intersection = output * encoded_target
    numerator = 2 * intersection.sum(3).sum(2) + smooth
    denominator = (output + encoded_target).sum(3).sum(2) + smooth
    loss_per_channel = weights * (1 - (numerator / denominator))

    return loss_per_channel.sum() / output.size(1)

Some comments (others might correct me):

  • AFAIK, creating and re-typing Variables is expensive, but it can be avoided with detach(), which returns a Variable sharing the same data but cut off from the dynamic graph (multiplying it by 0 then gives a fresh zero tensor).
  • FloatTensor[ByteTensor] indexing is also expensive (on CPU or GPU). I understand the first one is unavoidable, but the rest can be replaced by plain element-wise multiplications, which are very fast.
  • .sum(0) can be dropped and absorbed into the last sum. I’m not sure it actually speeds things up, but I personally prefer to keep things batch-wise as long as possible for the sake of readability.
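The second point can be checked directly. A minimal sketch, assuming a recent PyTorch and a bool mask that is True on pixels to ignore: zeroing by masked assignment and multiplying by the inverted mask give identical results, and the multiplication needs no explicit `expand`. Multiplying by the non-inverted mask would instead keep only the ignored pixels:

```python
import torch

x = torch.rand(2, 3, 4, 4)                 # e.g. a one-hot target or probabilities
ignore = torch.rand(2, 1, 4, 4) > 0.7      # True where the pixel should be ignored

# Masked assignment: needs the mask expanded to x's shape, and modifies a copy in place
a = x.clone()
a[ignore.expand_as(a)] = 0

# Multiplication by the *inverted* mask: broadcasting handles the class dim
b = x * (~ignore).to(x.dtype)

print(torch.equal(a, b))
```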

If it’s still slow, you can profile the time spent in each function; be sure to make CUDA calls blocking:

CUDA_LAUNCH_BLOCKING=1 python -m cProfile -o myscript.cprof myscript.py
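The resulting `myscript.cprof` file can be read with Python's standard `pstats` module, e.g. `pstats.Stats("myscript.cprof")`. A self-contained sketch that profiles a toy function in-process instead; `slow_sum` is a hypothetical stand-in for the code being timed, and for GPU code the blocking setting above still applies:

```python
import cProfile
import io
import pstats


def slow_sum(n):
    # Deliberately naive loop so it shows up in the profile
    total = 0
    for i in range(n):
        total += i * i
    return total


profiler = cProfile.Profile()
profiler.enable()
result = slow_sum(100_000)
profiler.disable()

# Sort by cumulative time and capture the top 5 entries as text
buffer = io.StringIO()
stats = pstats.Stats(profiler, stream=buffer).sort_stats("cumulative")
stats.print_stats(5)
report = buffer.getvalue()
print(report)
```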

More info here and here (I personally like SnakeViz)


(Pierre Antoine Ganaye) #14

I compared the output of your code with my current implementation, it does not yield the same value, after searching for the difference, I found the bug is here

mask = Variable(mask.unsqueeze(1)).type_as(encoded_target)  # no need to expand, thanks to array broadcasting
output *= mask
encoded_target *= mask # not equivalent to encoded_target[mask] = 0

(Clément Pinard) #15

Yes, you need to write
mask = Variable(~mask.unsqueeze(1)).type_as(encoded_target)

Sorry about that typo.


(Pierre Antoine Ganaye) #16

Oh yes, correct! I didn’t even notice the problem… I finally get the same results, thanks!


(Clément Pinard) #17

And is it faster? (I think it is, but I didn’t check.)


(Pierre Antoine Ganaye) #18

Unfortunately not: it’s about the same speed, and sometimes slower.