RuntimeError: Expected object of type torch.LongTensor but found type torch.cuda.LongTensor for argument #3 'index'

@albanD @aerinykim Could you take a look at this, please?