It’d be nice if the error message for torch.cat were more specific. For example, when I run the following:
import torch
a = torch.FloatTensor(1)
b = torch.DoubleTensor(1)
c = torch.cat((a, b)) # Error since a and b are different types
I get the following error:
Traceback (most recent call last):
  File "foo.py", line 5, in <module>
    c = torch.cat((a, b))
TypeError: cat received an invalid combination of arguments - got (tuple), but expected one of:
 * (sequence[torch.FloatTensor] seq)
      didn't match because some of the arguments have invalid types: (tuple)
 * (sequence[torch.FloatTensor] seq, int dim)
It’d be nice if it said something like “got mixed types but expected…”. The current message only reports that it got a tuple, so I spent a while double-checking my inputs before I realized the types were different.
Looks like this confusion was also the cause for this post.
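In the meantime, one workaround is to check the types yourself before calling torch.cat. Here is a minimal sketch, assuming a hypothetical helper named checked_cat (my own name, not part of PyTorch) and relying on Tensor.type() returning the type name as a string when called with no arguments:

```python
import torch


def checked_cat(tensors, dim=0):
    """Hypothetical wrapper (not a PyTorch API): fail early on mixed types."""
    tensors = tuple(tensors)
    # Tensor.type() with no arguments returns the type name as a string.
    type_names = [t.type() for t in tensors]
    if len(set(type_names)) > 1:
        # Name every type that was passed so the mismatch is easy to spot.
        raise TypeError(
            "checked_cat got mixed tensor types: " + ", ".join(type_names)
        )
    return torch.cat(tensors, dim)


a = torch.FloatTensor(1)
b = torch.DoubleTensor(1)
try:
    c = checked_cat((a, b))
except TypeError as err:
    print(err)  # the message lists the type of each input
```

This only moves the check into user code; it doesn't change what torch.cat itself reports.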