Retaining grad_fn for one-hot encoded tensors

Hi everyone,

I’m using nn.functional.one_hot() to convert predicted class labels [N×1×H×W] to one-hot encoded tensors [N×C×H×W]. The goal is to calculate the loss with a custom loss function.

The problem is nn.functional.one_hot() sets grad_fn = None so the calculated loss using such encoded predictions fails in backpropagation. How may I retain grad_fn for the one-hot encoded tensors?

Any thoughts are welcome!

Hi Stark!

You can’t, because it doesn’t make sense to. one_hot() isn’t
usefully differentiable, so a loss function that uses* it won’t be either.

one_hot() takes a torch.int64 argument and returns a torch.int64
result. PyTorch doesn’t even permit such integer-valued tensors to
have requires_grad = True (because it doesn’t make sense).
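A quick way to see this (a minimal sketch; the three labels and `num_classes=3` are made up for illustration):

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([0, 2, 1])           # integer class labels (torch.int64)
onehot = F.one_hot(labels, num_classes=3)  # shape [3, 3], also torch.int64

print(onehot.dtype)    # torch.int64
print(onehot.grad_fn)  # None: one_hot() is not recorded by autograd

# Integer tensors cannot require gradients; autograd refuses them outright.
raised = False
try:
    onehot.requires_grad_(True)
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```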

*) To be more precise, a loss function could depend on the result
of one_hot() and also on the results of some differentiable tensor
operations. Such a loss function would be usefully differentiable
in that you could backpropagate through the differentiable tensor
operations. But you wouldn’t be able to backpropagate through the
one_hot() part of the computation (nor anything upstream of it).
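To illustrate the footnote, here is a sketch in which a made-up loss mixes a one-hot target with differentiable predictions; `logits`, `labels`, and the squared-error loss are assumptions chosen only for illustration. Gradients reach `logits` through `softmax()`, while the one-hot branch just contributes constants:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3, requires_grad=True)   # made-up differentiable network output
labels = torch.tensor([0, 2, 1, 2])              # integer labels: no gradient possible

targ = F.one_hot(labels, num_classes=3).float()  # float copy, but still no grad_fn
pred = torch.softmax(logits, dim=1)              # differentiable

loss = ((pred - targ) ** 2).sum()                # depends on both branches
loss.backward()

print(logits.grad is not None)  # True: backprop flows through softmax()
print(targ.requires_grad)       # False: nothing flows into the one_hot() branch
```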

Best.

K. Frank

1 Like

Thanks for your reply.

What are the options if one wants to use Dice or IOU as a loss function? Using class indices doesn’t make sense when there is no ordinal relationship between the classes. The only reasonable solution seems to be one-hot encoding, but that is useless if it costs the gradients. Any other ideas?

Hi Stark!

Let me imagine that your use case is something like multi-class
semantic segmentation. Your target is an “image” of class labels,
that is this pixel is “class-3,” that pixel is “class-1,” and so on. The
prediction from your network is a set of nClass raw-score logits
for each pixel, with shape [nBatch, nClass, height, width].

Softmax (perhaps better named SoftOneHotEncodedArgMax)
will differentiably push your predictions towards 0.0 and 1.0.

So use soft_labels = torch.softmax (alpha * pred, dim = 1)

alpha can be used to tune how “hard” your soft one-hot labels
become – as you increase alpha you will push the labels closer
to being exactly zero or one. But don’t make alpha too large – if
the labels saturate and become exactly zero and one you will lose
useful differentiability, and as they get close to saturating, your
differentiability will become less and less useful.
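A small sketch of the effect of alpha (the random logits and the alpha values 1, 5, and 50 are arbitrary choices for illustration). As alpha grows, the largest soft label at each pixel moves toward 1.0:

```python
import torch

torch.manual_seed(0)
pred = torch.randn(1, 4, 2, 2)   # made-up raw logits [nBatch, nClass, height, width]

results = {}
for alpha in (1.0, 5.0, 50.0):
    soft_labels = torch.softmax(alpha * pred, dim=1)
    # largest soft label at the least confident pixel
    results[alpha] = soft_labels.max(dim=1).values.min().item()
    print(alpha, results[alpha])
```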

Probably the most straightforward way to calculate your “soft”
intersection-over-union will be to one-hot your integer-class-label
target so that it also has shape [nBatch, nClass, height, width].
Letting pred be your soft-label predictions and targ be your
one-hotted class labels, then:

iou = (pred * targ).sum (dim = (2, 3)) / (pred + targ - pred * targ).sum (dim = (2, 3))

(If pred and targ were binary, that is, just 0 or 1, pred * targ
would be the intersection and pred + targ - pred * targ the union.)

iou will now have shape [nBatch, nClass] and each element of it
will be the soft IOU for a specific class and a specific sample in the
batch.
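Putting these pieces together, a minimal soft-IOU loss might look like the following. This is a sketch under some assumptions: `soft_iou_loss`, `alpha`, and `eps` are hypothetical names, the loss is taken as one minus the mean IOU, the union is written as pred + targ - pred * targ (which reduces to the ordinary union when both are exactly 0 or 1), and `eps` guards against dividing by zero for a class absent from both prediction and target. Labels shaped [N×1×H×W] are squeezed to [N×H×W] before one-hotting.

```python
import torch
import torch.nn.functional as F

def soft_iou_loss(logits, labels, alpha=1.0, eps=1e-6):
    # logits: [nBatch, nClass, height, width] raw scores
    # labels: [nBatch, height, width] or [nBatch, 1, height, width] integer class labels
    if labels.dim() == 4:
        labels = labels.squeeze(1)
    n_class = logits.shape[1]
    pred = torch.softmax(alpha * logits, dim=1)                    # soft one-hot predictions
    targ = F.one_hot(labels, n_class).permute(0, 3, 1, 2).float()  # [N, H, W, C] -> [N, C, H, W]
    inter = (pred * targ).sum(dim=(2, 3))                          # soft intersection
    union = (pred + targ - pred * targ).sum(dim=(2, 3))            # soft union
    iou = (inter + eps) / (union + eps)                            # [nBatch, nClass]
    return 1.0 - iou.mean()

torch.manual_seed(0)
logits = torch.randn(2, 3, 4, 4, requires_grad=True)   # made-up network output
labels = torch.randint(0, 3, (2, 1, 4, 4))             # made-up integer labels
loss = soft_iou_loss(logits, labels)
loss.backward()                                        # gradients flow back to logits
print(loss.item(), logits.grad.shape)
```

Only the target goes through one_hot(); the prediction side stays in the autograd graph, which is what makes the loss usable for training.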

Best.

K. Frank

Your explanation makes sense for multi-class semantic segmentation.

What if I want the predicted class labels [N×1×H×W] from torch.sigmoid(logits) as hard 0 or 1 values, with shape [N×C×H×W], for a hard loss function? All such attempts, using thresholding, nn.functional.one_hot(),…, fail to retain the gradient function.

Hi Stark!

If by this you mean that you wish to perform binary semantic
segmentation, then you would do essentially as I described
above, but use torch.sigmoid (alpha * logits) (rather than
torch.softmax()). And, analogously, the larger you make your
choice of alpha, the “harder” your binary predictions will become.
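The binary analog is a sketch along the same lines, assuming logits and float 0/1 targets both of shape [nBatch, 1, height, width]; `soft_binary_iou_loss` and `eps` are hypothetical names, and the union is again written as pred + targ - pred * targ:

```python
import torch

def soft_binary_iou_loss(logits, targ, alpha=1.0, eps=1e-6):
    # logits, targ: [nBatch, 1, height, width]; targ holds float 0.0 / 1.0 values
    pred = torch.sigmoid(alpha * logits)                    # soft foreground probability
    inter = (pred * targ).sum(dim=(1, 2, 3))                # per-sample soft intersection
    union = (pred + targ - pred * targ).sum(dim=(1, 2, 3))  # per-sample soft union
    return 1.0 - ((inter + eps) / (union + eps)).mean()

torch.manual_seed(0)
logits = torch.randn(2, 1, 4, 4, requires_grad=True)   # made-up network output
targ = torch.randint(0, 2, (2, 1, 4, 4)).float()       # made-up binary mask
loss = soft_binary_iou_loss(logits, targ, alpha=2.0)
loss.backward()
```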

Yes, if you insist on a hard loss function, that is, something like
integer_valued_intersection / integer_valued_union,
you will, as you have seen, lose the gradients.

When you turn something that feeds into your loss function into an
integer, you’ve applied a discontinuous function that isn’t usefully
differentiable. So you need to design a “soft” version of your desired
loss function through which you can backpropagate.
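The difference shows up directly in autograd (a sketch with made-up logits; the factor 5.0 playing the role of alpha is arbitrary). Thresholding produces a tensor with no grad_fn, so a loss built only from it cannot be backpropagated, while the sigmoid version stays in the graph:

```python
import torch

torch.manual_seed(0)
logits = torch.randn(2, 1, 4, 4, requires_grad=True)   # made-up network output

# Hard version: the comparison returns a bool tensor that is cut off from autograd.
hard = (torch.sigmoid(logits) > 0.5).float()
# Soft version: a sharpened sigmoid stays in the autograd graph.
soft = torch.sigmoid(5.0 * logits)

print(hard.requires_grad, hard.grad_fn)   # False None
print(soft.requires_grad)                 # True

raised = False
try:
    hard.sum().backward()   # nothing to backpropagate through
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)

soft.sum().backward()       # works: gradients reach logits
print(logits.grad.shape)    # torch.Size([2, 1, 4, 4])
```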

As an aside, why not use BCEWithLogitsLoss (on a per-pixel basis)
for your binary semantic segmentation? People often use this loss
function to good effect.
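For reference, a minimal per-pixel usage sketch (tensor shapes and values are made up). Note that BCEWithLogitsLoss takes the raw logits, not sigmoid() outputs, and expects a float target of the same shape:

```python
import torch

torch.manual_seed(0)
logits = torch.randn(2, 1, 4, 4, requires_grad=True)   # made-up raw network output
targ = torch.randint(0, 2, (2, 1, 4, 4)).float()       # hard 0/1 mask, as float

loss_fn = torch.nn.BCEWithLogitsLoss()   # averages the per-pixel losses by default
loss = loss_fn(logits, targ)
loss.backward()
```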

Best.

K. Frank

Hi Frank,

That’s right. Using soft predictions with hard targets is common practice in binary and multi-class semantic segmentation problems.

Please share your thoughts on using a soft target in a soft loss function. Could it be a better option? I would be more than happy if you could suggest some relevant work on this.

Hi Stark!

Using soft targets is a perfectly reasonable thing to do. BCEWithLogitsLoss (as
well as BCELoss) supports soft targets: you just pass them in as
the target. You can certainly contemplate using soft targets with
other “soft” loss functions, with the details depending on the specific
loss function.
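A sketch with soft targets (the values are made up). The hand-written formula is included only to show that BCEWithLogitsLoss computes the usual binary cross-entropy against the soft target; with the default mean reduction, the gradient with respect to each logit is (sigmoid(logit) - target) divided by the number of elements:

```python
import torch

logits = torch.tensor([2.0, -1.0, 0.0], requires_grad=True)   # made-up outputs
soft_targ = torch.tensor([0.93, 0.24, 0.5])                   # soft labels in [0, 1]

loss = torch.nn.BCEWithLogitsLoss()(logits, soft_targ)
loss.backward()

# Same quantity written out by hand: -[t*log(p) + (1-t)*log(1-p)], averaged
p = torch.sigmoid(logits.detach())
manual = -(soft_targ * torch.log(p) + (1 - soft_targ) * torch.log(1 - p)).mean()
print(loss.item(), manual.item())
```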

(As an aside, it also makes sense to use soft labels with cross-entropy.
PyTorch’s CrossEntropyLoss accepted only class indices as targets before
version 1.10, which added support for class-probability targets, but in any
case it is straightforward to write your own soft-label version of cross-entropy.)
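One way to write such a soft-label cross-entropy, as a minimal sketch (`soft_cross_entropy` is a hypothetical name; `soft_targ` is assumed to be a probability distribution over dim 1, the class dimension). With one-hot targets it reduces to the ordinary cross-entropy, which the usage lines compare against:

```python
import torch
import torch.nn.functional as F

def soft_cross_entropy(logits, soft_targ):
    # logits, soft_targ: [nBatch, nClass, ...]; soft_targ sums to 1 over dim 1
    log_probs = F.log_softmax(logits, dim=1)
    return -(soft_targ * log_probs).sum(dim=1).mean()

torch.manual_seed(0)
logits = torch.randn(4, 3)                # made-up scores
labels = torch.tensor([0, 2, 1, 2])       # made-up hard labels
hard = F.one_hot(labels, 3).float()       # hard labels as (degenerate) soft labels

# the two values agree
print(soft_cross_entropy(logits, hard).item(), F.cross_entropy(logits, labels).item())
```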

As to whether soft targets could be a better option: probably not (but it depends).

Your targets are what you know about your training data. If you
have (correct) hard class-“0” / class-“1” labels, there is no benefit
in artificially fuzzing them into soft labels – you would just be throwing
away information. On the other hand, if your labels are naturally
soft, you will probably make things worse by making them artificially
hard.

To illustrate what I mean, consider training a binary cat-dog classifier.
If all of your training images are clear-cut (“Yup, that’s a cat; yup,
that’s a dog.”), then you should use hard labels because doing so
uses all of the information you have.

But if some of your images are truly uncertain – maybe they are blurry
or partially obscured – you should use soft labels, because that’s what
you actually know.

In practice, why might you have soft labels and how would you get
them? You could, for example, have 100 people label each image.
If 93 people label a specific image as “dog” (and 7 as “cat”), you
would give that image a soft label of 0.93. If 24 (of the 100) people
label an image as “dog,” you would give it a soft label of 0.24.
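Turning such vote counts into soft labels is just a division (a sketch; the vote counts beyond the 93 and 24 of the example are hypothetical, as are the all-zero logits). The result can be passed directly as the target of BCEWithLogitsLoss:

```python
import torch

# Hypothetical counts: how many of 100 labelers said "dog" for each image
dog_votes = torch.tensor([93, 24, 100, 0])
n_labelers = 100

soft_labels = dog_votes.float() / n_labelers   # 0.93, 0.24, 1.0, 0.0

logits = torch.zeros(4, requires_grad=True)    # made-up classifier outputs
loss = torch.nn.BCEWithLogitsLoss()(logits, soft_labels)
loss.backward()
```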

The fact that people don’t always agree about the images indicates
that the image labels are uncertain, that you might want to use soft
labels, and gives you a quantitative way to assign such soft labels.

Best.

K. Frank

2 Likes