Better torch.cat error message

It’d be nice if the error message for torch.cat were more specific. For example, when I run the following:

import torch

a = torch.FloatTensor(1)
b = torch.DoubleTensor(1)
c = torch.cat((a, b))  # Error since a and b are different types

I get the following error:

Traceback (most recent call last):
  File "foo.py", line 5, in <module>
    c = torch.cat((a, b))
TypeError: cat received an invalid combination of arguments - got (tuple), but expected one of:
 * (sequence[torch.FloatTensor] seq)
      didn't match because some of the arguments have invalid types: (tuple)
 * (sequence[torch.FloatTensor] seq, int dim)

It’d be nice if it said something like “got mixed types but expected…”. I spent a while double-checking my inputs before I realized the types were different.
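Until the message improves, one workaround is to check the types yourself before calling torch.cat. The sketch below uses a hypothetical helper, checked_cat (not part of PyTorch). It compares each input's Tensor.type() string and raises a TypeError that names the mismatched types. Depending on your PyTorch version, torch.cat may promote mixed floating-point types instead of raising, so treat this as a stricter, opt-in check rather than a copy of the library's behavior.

```python
import torch


def checked_cat(tensors, dim=0):
    """Hypothetical wrapper around torch.cat that reports mixed tensor types.

    Not a PyTorch API: it only adds a type check in front of torch.cat.
    """
    tensors = list(tensors)
    # Type string of each input, e.g. 'torch.FloatTensor' or 'torch.DoubleTensor'.
    types = [t.type() for t in tensors]
    if len(set(types)) > 1:
        # List every argument's position and type so the mismatch is obvious.
        details = ", ".join(f"[{i}] {name}" for i, name in enumerate(types))
        raise TypeError(
            f"cat got mixed tensor types: {details}; "
            "cast them to a common type first (e.g. with .float())"
        )
    return torch.cat(tensors, dim)


a = torch.FloatTensor(1)
b = torch.DoubleTensor(1)
try:
    checked_cat((a, b))
except TypeError as err:
    # The message names both torch.FloatTensor and torch.DoubleTensor.
    print(err)
```

Casting one input so the types agree, e.g. checked_cat((a, b.float())), lets the call go through to torch.cat.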

It looks like the same confusion was also the cause of this post.
