def select_beam_items(x, ids):
id_shape = list(ids.size())
id_rank = len(id_shape)
assert len(id_shape) == 2
x_shape = list(x.size())
x = torch.reshape(x, [batch_size, K] + x_shape[1:])
x_rank = len(x_shape) + 1
assert x_rank >= 2
if id_rank < x_rank:
ids = torch.reshape(
ids, id_shape + [1] * (x_rank - id_rank))
ids = ids.expand(id_shape + x_shape[1:])
x=x.type(torch.int64)
y = torch.gather(x, 1, ids)
y = torch.reshape(y, x_shape)
return y
My code was running earlier, but now I am getting RuntimeError: gather(): Expected dtype int64 for index
What I suspect is that ids are a float tensor and gather requires them to be integer could you try doing y = torch.gather(x, 1, ids.int())
to see if that is the case?
Also, to have better code formatting put the code into three ```
like ``` if x == 0: print 0 ```
when done on a code block it will retrain the indentations and help with the formatting.
def select_beam_items(x, ids):
id_shape = list(ids.size())
id_rank = len(id_shape)
assert len(id_shape) == 2
x_shape = list(x.size())
x = torch.reshape(x, [batch_size, K] + x_shape[1:])
x_rank = len(x_shape) + 1
assert x_rank >= 2
if id_rank < x_rank:
ids = torch.reshape(
ids, id_shape + [1] * (x_rank - id_rank))
ids = ids.expand(id_shape + x_shape[1:])
#ids=ids.type(torch.int64)
y = torch.gather(x, 1, ids.int64())
y = torch.reshape(y, x_shape)
return y
RuntimeError: gather(): Expected dtype int64 for index same error
levitation
(Roland Pihlakas)
April 21, 2023, 12:03am
5
Use torch.gather(x, 1, ids.long())
instead. Note .long()
here.
2 Likes