Convert int into one-hot format

(movefast) #22

I ran into this issue myself and did some searching. torch.eye(num_labels).index_select(dim=0, index=labels) also seems to work pretty well, in addition to the scatter_ solution in the 0.3 release.
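
As a minimal sketch of the index_select approach (the label values and num_labels below are made up for illustration): row i of the identity matrix is already the one-hot vector for class i, so selecting rows by label gives the encoding. Note that index_select expects a 1-D integer index.

```python
import torch

num_labels = 4                       # assumed number of classes
labels = torch.tensor([0, 2, 3, 2])  # assumed 1-D integer labels (int64)

# Row i of the identity matrix is the one-hot vector for class i,
# so picking rows by label yields the one-hot encoding.
one_hot = torch.eye(num_labels).index_select(dim=0, index=labels)

print(one_hot.shape)  # torch.Size([4, 4])
```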

(Sajid Iqbal) #23

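import torch

def get_one_hot(preds, gt):
    # Assumed reconstruction: a zero tensor shaped like preds, [B, C, H, W]
    encoded_target = torch.zeros_like(preds)
    target = gt.unsqueeze(1)  # target now has shape [B, 1, H, W], e.g. [20, 1, 240, 240]
    unseq = target.long()  # scatter_ needs an int64 index

    # encoded_target.scatter_(dim, index, val)
    # unseq must have size 1 along dimension 'dim' (here dim=1)
    encoded_target.scatter_(1, unseq, 1)
    return encoded_target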

It returns the one-hot encoding of the target. In my case the target was of shape [1,1,240,240] and preds of shape [1,5,240,240].
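
For comparison, here is a hedged sketch of the same segmentation-style encoding using torch.nn.functional.one_hot, which is available in more recent PyTorch releases. This is a swapped-in alternative, not the code from the post. The mask is random and the shapes are assumptions chosen to match the [1,5,240,240] example.

```python
import torch
import torch.nn.functional as F

num_classes = 5                                     # assumed: the 5 channels of preds
gt = torch.randint(0, num_classes, (1, 240, 240))   # assumed [B, H, W] integer mask

# scatter_-based encoding along the channel dimension, as in the post above
scattered = torch.zeros(1, num_classes, 240, 240).scatter_(1, gt.unsqueeze(1), 1)

# F.one_hot appends the class axis last ([B, H, W, C]); move it to dim 1
builtin = F.one_hot(gt, num_classes).permute(0, 3, 1, 2).float()

print(torch.equal(scattered, builtin))  # True
```

F.one_hot returns an integer tensor, hence the .float() before comparing with the scatter_ result.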

(Justheuristic) #24

Here’s a TensorFlow-like solution based on previous code in this thread:

import torch
from torch.autograd import Variable

def to_one_hot(y, n_dims=None):
    """ Take integer y (tensor or variable) with n dims and convert it to 1-hot representation with n+1 dims. """
    # Unwrap a Variable to its underlying tensor
    y_tensor = y.data if isinstance(y, Variable) else y
    y_tensor = y_tensor.type(torch.LongTensor).view(-1, 1)
    # Infer the number of classes from the largest label if not given
    n_dims = n_dims if n_dims is not None else int(torch.max(y_tensor)) + 1
    y_one_hot = torch.zeros(y_tensor.size()[0], n_dims).scatter_(1, y_tensor, 1)
    y_one_hot = y_one_hot.view(*y.shape, -1)
    return Variable(y_one_hot) if isinstance(y, Variable) else y_one_hot
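
Since PyTorch 0.4 merged Variable into Tensor, the Variable handling can be dropped. Below is a sketch of the same flatten, scatter_, reshape idea without it. The name to_one_hot_tensor and the sample input are hypothetical, used only for illustration.

```python
import torch

def to_one_hot_tensor(y, n_dims=None):  # hypothetical name, not from the thread
    """One-hot encode an integer tensor of any shape; adds a trailing class axis."""
    flat = y.long().reshape(-1, 1)            # one label per row
    if n_dims is None:
        n_dims = int(flat.max()) + 1          # infer class count from the largest label
    one_hot = torch.zeros(flat.size(0), n_dims, device=y.device).scatter_(1, flat, 1)
    return one_hot.view(*y.shape, n_dims)     # restore original dims, plus class axis

y = torch.tensor([[0, 2], [1, 3]])            # assumed sample labels, shape [2, 2]
encoded = to_one_hot_tensor(y)
print(encoded.shape)  # torch.Size([2, 2, 4])
```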

(Guillaume Dumont) #25

It is also possible to abuse broadcasting and do:

# some labels
labels = torch.arange(3)
labels = labels.reshape(3, 1)

num_classes = 4
one_hot_target = (labels == torch.arange(num_classes).reshape(1, num_classes)).float()


 1  0  0  0
 0  1  0  0
 0  0  1  0
[torch.FloatTensor of size (3,4)]
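
The broadcasting trick should also extend to labels of any shape if the new axis is added last. A small sketch, with label values assumed for illustration:

```python
import torch

num_classes = 4
labels = torch.tensor([[3, 0], [1, 1]])  # assumed labels of shape [2, 2]

# Comparing [2, 2, 1] against [4] broadcasts to [2, 2, 4]:
# each label is tested against every class index.
one_hot = (labels.unsqueeze(-1) == torch.arange(num_classes)).float()

print(one_hot.shape)  # torch.Size([2, 2, 4])
```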

(Rishabh Agrahari) #26

You can use the torch.eye function for this:

import torch

def one_hot_embedding(labels, num_classes):
    """Embed labels in one-hot form.

    Args:
      labels: (LongTensor) class labels, sized [N,].
      num_classes: (int) number of classes.

    Returns:
      (tensor) encoded labels, sized [N, #classes].
    """
    y = torch.eye(num_classes)  # row i is the one-hot vector for class i
    return y[labels]            # pick one row per label

This should help!
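
A short usage sketch of the torch.eye indexing idea, assuming 2-D labels (the values are made up). Unlike index_select, plain indexing accepts an index tensor of any shape, and creating the identity matrix on the labels' device keeps everything in one place.

```python
import torch

num_classes = 3
labels = torch.tensor([[0, 2], [1, 1]])  # assumed 2-D integer labels

# Build the identity matrix on the same device as the labels
eye = torch.eye(num_classes, device=labels.device)

# Indexing with a [2, 2] tensor returns one row per label: shape [2, 2, 3]
one_hot = eye[labels]

print(one_hot.shape)  # torch.Size([2, 2, 3])
```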