RuntimeError: gather(): Expected dtype int64 for index

def select_beam_items(x, ids):
id_shape = list(ids.size())
id_rank = len(id_shape)
assert len(id_shape) == 2
x_shape = list(x.size())
x = torch.reshape(x, [batch_size, K] + x_shape[1:])
x_rank = len(x_shape) + 1
assert x_rank >= 2
if id_rank < x_rank:
ids = torch.reshape(
ids, id_shape + [1] * (x_rank - id_rank))
ids = ids.expand(id_shape + x_shape[1:])
x=x.type(torch.int64)
y = torch.gather(x, 1, ids)
y = torch.reshape(y, x_shape)
return y

My code was running earlier, but now I am getting RuntimeError: gather(): Expected dtype int64 for index

What I suspect is that ids are a float tensor and gather requires them to be integer could you try doing y = torch.gather(x, 1, ids.int()) to see if that is the case?

Also, to have better code formatting put the code into three ``` like ``` if x == 0: print 0 ``` when done on a code block it will retrain the indentations and help with the formatting.

def select_beam_items(x, ids):
id_shape = list(ids.size())
id_rank = len(id_shape)
assert len(id_shape) == 2
x_shape = list(x.size())
x = torch.reshape(x, [batch_size, K] + x_shape[1:])
x_rank = len(x_shape) + 1
assert x_rank >= 2
if id_rank < x_rank:
ids = torch.reshape(
ids, id_shape + [1] * (x_rank - id_rank))
ids = ids.expand(id_shape + x_shape[1:])
#ids=ids.type(torch.int64)
y = torch.gather(x, 1, ids.int64())
y = torch.reshape(y, x_shape)
return y

RuntimeError: gather(): Expected dtype int64 for index same error