I am getting this error when I run the PyTorch code, I try to change argmax to someother datatype but not working.
please help
def _mask(prev_generated_seq):
prev_mask = torch.eq(prev_generated_seq, 1)
lengths = torch.argmax(prev_mask,dim=1)
max_len = prev_generated_seq.size(1)
mask = []
for i in range(prev_generated_seq.size(0)):
if lengths[i] == 0:
mask_line = [0] * max_len
else:
mask_line = [0] * lengths[i].item()
mask_line.extend([1] * (max_len - lengths[i].item()))
mask.append(mask_line)
mask = torch.ByteTensor(mask)
if args.cuda:
mask = mask.cuda()
return prev_generated_seq.data.masked_fill_(mask, 0)
**Error**
File “main.py”, line 179, in
train_epoches(abstracts, model, config.epochs, teacher_forcing_ratio=1)
File “main.py”, line 155, in train_epoches
target_variables, model, teacher_forcing_ratio)
File “main.py”, line 139, in train_batch
prev_generated_seq = _mask(prev_generated_seq)
File “main.py”, line 101, in _mask
lengths = torch.argmax(prev_mask,dim=1)
RuntimeError: “argmax_cuda” not implemented for ‘Bool’