I have an index tensor like the input we send to a BERT model after tokenization, so its shape is [batch_size, seq_len], and I have to return a tensor of 0s and 1s (1s at the matched indices) of shape [batch_size, seq_len, vocab_size]. For example, with seq_len = 3 and vocab_size = 5, the row [0, 2, 4] should become [[1, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 0, 1]].
The closest solution I could find is to use vmap.
I defined a function that takes an index and a dimension, creates a torch.zeros tensor of the given dim (1-dimensional in my case), and sets the entry at that index to 1.
def idx2tensor(val, dim):  # apply this across the second dimension
    empty_tensor = torch.zeros(dim)
    empty_tensor[val] = 1
    return empty_tensor
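For a single index the function does what I want, e.g.:
idx2tensor(torch.tensor(3), 5)   # tensor([0., 0., 0., 1., 0.])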
Now I apply vmap to an input tensor of shape [batch_size, seq_length] and expect to get [batch_size, seq_length, dim] back. Following the example in the docs, I wrap the function like this:
batched_idx2ten = torch.vmap(idx2tensor, in_dims=(0, None))
Then I apply batched_idx2ten to a random input and a fixed dim:
x = 10 * torch.rand(4, 10) - 1
x = torch.tensor(x, dtype=torch.int64)
batched_idx2ten(x, 15)
I am now stuck with the following error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-26-80c94fc2eaa3> in <cell line: 2>()
1 x = 10 * torch.rand(4, 10) - 1
----> 2 batched_idx2ten(x, 15)
4 frames
/usr/local/lib/python3.10/dist-packages/torch/_functorch/apis.py in wrapped(*args, **kwargs)
186 # @functools.wraps(func)
187 def wrapped(*args, **kwargs):
--> 188 return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
189
190 return wrapped
/usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py in vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
276
277 # If chunk_size is not specified.
--> 278 return _flat_vmap(
279 func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
280 )
/usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py in fn(*args, **kwargs)
42 def fn(*args, **kwargs):
43 with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 44 return f(*args, **kwargs)
45 return fn
46
/usr/local/lib/python3.10/dist-packages/torch/_functorch/vmap.py in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
389 try:
390 batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)
--> 391 batched_outputs = func(*batched_inputs, **kwargs)
392 return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
393 finally:
<ipython-input-25-c7783e629350> in idx2tensor(val, dim)
1 def idx2tensor(val, dim): # apply this across second dimension
2 empty_tensor = torch.zeros(dim)
----> 3 empty_tensor[val] = 1
4
5 batched_idx2ten = torch.vmap(idx2tensor, in_dims=(0, None))
RuntimeError: vmap: index_put_(self, *extra_args) is not possible because there exists a Tensor `other`
in extra_args that has more elements than `self`. This happened due to `other` being vmapped over but
`self` not being vmapped over in a vmap. Please try to use out-of-place operators instead of index_put_.
If said operator is being called inside the PyTorch framework, please file a bug report instead.
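The error suggests using out-of-place operators, so the only workaround I can think of is a comparison-based sketch along these lines (idx2tensor_oop is just a hypothetical name; it builds each one-hot row by comparing against torch.arange instead of assigning in place):

import torch

def idx2tensor_oop(val, dim):
    # out-of-place: compare the indices against 0..dim-1 instead of writing into a zeros tensor
    return (val.unsqueeze(-1) == torch.arange(dim)).long()

batched_idx2ten_oop = torch.vmap(idx2tensor_oop, in_dims=(0, None))
x = torch.randint(0, 15, (4, 10))           # stand-in for the tokenized indices
print(batched_idx2ten_oop(x, 15).shape)     # torch.Size([4, 10, 15])

Is there a cleaner or built-in way to do this?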
Hi Satya!
I believe that you are looking for torch.nn.functional.one_hot():
>>> import torch
>>> torch.__version__
'2.2.2'
>>> _ = torch.manual_seed (2024)
>>> x = 10 * torch.rand(4, 10) - 1
>>> x = torch.tensor(x, dtype=torch.int64)
<stdin>:1: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
>>> x
tensor([[4, 7, 8, 0, 0, 2, 1, 5, 3, 6],
[0, 4, 0, 6, 8, 3, 2, 8, 6, 5],
[7, 8, 5, 5, 1, 4, 1, 6, 0, 0],
[1, 0, 8, 5, 5, 3, 4, 0, 2, 4]])
>>> result = torch.nn.functional.one_hot (x, 15)
>>> result.shape
torch.Size([4, 10, 15])
>>> result[0]
tensor([[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]])
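As an aside, one_hot() returns a LongTensor; if you need floating-point values downstream (say, to multiply against an embedding matrix), just cast the result:
>>> result.float().dtype
torch.float32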
Best.
K. Frank