Create binary tensor from index tensor

I have an index tensor of the kind that goes into a BERT model after tokenization, so its shape is [batch_size, seq_len]. I need to return a tensor of 0's and 1's (1's at the matched indices) of shape [batch_size, seq_len, vocab_size].

The closest solution I could find is to use vmap.
I defined a function that takes an index and a dimension, creates a torch.zeros tensor of the given dim (1-dimensional in my case), and sets the entry at that index to 1.

def idx2tensor(val, dim):    # apply this across the second dimension
    empty_tensor = torch.zeros(dim)
    empty_tensor[val] = 1

Now I apply vmap to an input tensor of shape [batch_size, seq_length], expecting to get [batch_size, seq_length, dim] back:

batched_idx2ten = torch.vmap(idx2tensor, in_dims=(0, None))

(following this example in the docs). Applying batched_idx2ten to a random input and a fixed dim:

x = 10 * torch.rand(4, 10) - 1
x = torch.tensor(x, dtype=torch.int64)
batched_idx2ten(x, 15)

I am now stuck with the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-26-80c94fc2eaa3> in <cell line: 2>()
      1 x = 10 * torch.rand(4, 10) - 1
----> 2 batched_idx2ten(x, 15)

4 frames
/usr/local/lib/python3.10/dist-packages/torch/_functorch/apis.py in wrapped(*args, **kwargs)
    186     # @functools.wraps(func)
    187     def wrapped(*args, **kwargs):
--> 188         return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
    189 
    190     return wrapped

/usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py in vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
    276 
    277     # If chunk_size is not specified.
--> 278     return _flat_vmap(
    279         func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
    280     )

/usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py in fn(*args, **kwargs)
     42     def fn(*args, **kwargs):
     43         with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 44             return f(*args, **kwargs)
     45     return fn
     46 

/usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
    389     try:
    390         batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)
--> 391         batched_outputs = func(*batched_inputs, **kwargs)
    392         return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
    393     finally:

<ipython-input-25-c7783e629350> in idx2tensor(val, dim)
      1 def idx2tensor(val, dim):         # apply this across second dimension
      2     empty_tensor = torch.zeros(dim)
----> 3     empty_tensor[val] = 1
      4 
      5 batched_idx2ten = torch.vmap(idx2tensor, in_dims=(0, None))

RuntimeError: vmap: index_put_(self, *extra_args) is not possible because there exists a Tensor `other`
in extra_args that has more elements than `self`. This happened due to `other` being vmapped over but
`self` not being vmapped over in a vmap. Please try to use out-of-place operators instead of index_put_. 
If said operator is being called inside the PyTorch framework, please file a bug report instead.

Hi Satya!

I believe that you are looking for torch.nn.functional.one_hot():

>>> import torch
>>> torch.__version__
'2.2.2'
>>> _ = torch.manual_seed (2024)
>>> x = 10 * torch.rand(4, 10) - 1
>>> x = torch.tensor(x, dtype=torch.int64)
<stdin>:1: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
>>> x
tensor([[4, 7, 8, 0, 0, 2, 1, 5, 3, 6],
        [0, 4, 0, 6, 8, 3, 2, 8, 6, 5],
        [7, 8, 5, 5, 1, 4, 1, 6, 0, 0],
        [1, 0, 8, 5, 5, 3, 4, 0, 2, 4]])
>>> result = torch.nn.functional.one_hot (x, 15)
>>> result.shape
torch.Size([4, 10, 15])
>>> result[0]
tensor([[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]])
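
As a side note, the vmap error message itself points at a fix: replace the in-place write empty_tensor[val] = 1 (an index_put_) with an out-of-place operation. Here is a minimal sketch of that alternative, assuming you still want to go through vmap; it uses two nested vmap levels, one for the batch dimension and one for the sequence dimension:

import torch

# out-of-place version of idx2tensor: compare the index against an arange
# to build the one-hot row, so vmap never sees an index_put_
def idx2tensor(val, dim):
    return (torch.arange(dim) == val).to(torch.int64)

# one vmap level per batched dimension: outer over batch_size, inner over seq_len
batched_idx2ten = torch.vmap(torch.vmap(idx2tensor, in_dims=(0, None)), in_dims=(0, None))

x = torch.randint(0, 9, (4, 10))    # [batch_size, seq_len] of token indices
result = batched_idx2ten(x, 15)     # [batch_size, seq_len, 15]
print(result.shape)                 # torch.Size([4, 10, 15])

That said, one_hot() (or a plain broadcast comparison such as (x.unsqueeze(-1) == torch.arange(15)).long()) does the same thing without vmap at all.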

Best.

K. Frank
