I am tracing a model in which the output of `torch.split` is passed directly to `torch.cat`:
```python
import torch
from torch.fx import symbolic_trace

class M(torch.nn.Module):
    def forward(self, x):
        y = torch.split(x, 4, 2)
        return torch.cat(y, 0)

m = M()
traced = symbolic_trace(m)
print(traced.graph)
```
Running this script raises the following error:
```
TypeError: cat() received an invalid combination of arguments - got (Proxy, int), but expected one of:
 * (tuple of Tensors tensors, int dim, *, Tensor out)
 * (tuple of Tensors tensors, name dim, *, Tensor out)
```
I have never run into this with other torch functions. The difference I can see here is that the intermediate value between split and cat is a sequence of tensors rather than a single tensor, and FX apparently does not handle that well: the error shows that cat received a single Proxy where it expected a tuple of tensors.
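A minimal sketch that illustrates the point, and doubles as a workaround when the number of chunks is known in advance. During symbolic tracing, `torch.split` returns one `Proxy` standing for the whole tuple, and indexing that `Proxy` records one `operator.getitem` node per element, so `torch.cat` then receives a real Python list of Proxies. This assumes dim 2 always has size 8 (a split size of 4 therefore yields two chunks); `num_chunks` is a hypothetical parameter, not part of the original model.

```python
import torch
from torch.fx import symbolic_trace


class M(torch.nn.Module):
    def __init__(self, num_chunks: int):
        super().__init__()
        # Assumption: the number of chunks produced by split is fixed and
        # known before tracing (here: dim 2 has size 8, split size 4 -> 2).
        self.num_chunks = num_chunks

    def forward(self, x):
        # While tracing, y is a single Proxy for the whole tuple from split.
        y = torch.split(x, 4, 2)
        # Indexing the Proxy records one getitem node per chunk, so cat
        # gets an actual list of Proxies instead of one opaque Proxy.
        return torch.cat([y[i] for i in range(self.num_chunks)], 0)


m = M(num_chunks=2)
traced = symbolic_trace(m)
print(traced.graph)

x = torch.randn(2, 3, 8)
# The traced module should compute the same result as eager mode.
assert torch.equal(traced(x), m(x))
```

Because `num_chunks` is baked into the graph, an input with a different size along dim 2 would give wrong results or an index error, so this only fits models with a fixed split layout.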
Is this the intended behaviour, or am I missing something? Is there a workaround?
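If the number of chunks is not known ahead of time, one possible workaround is to keep the split/cat pair out of the traced graph: move it into its own submodule and tell the tracer to treat that submodule as a leaf. `SplitCat` and `LeafTracer` are hypothetical names; `torch.fx.Tracer.is_leaf_module`, `Tracer.trace` and `torch.fx.GraphModule` are existing FX APIs. FX also documents `torch.fx.wrap` for keeping a free function opaque to the tracer, which is a similar idea for functions instead of modules.

```python
import torch
import torch.fx


class SplitCat(torch.nn.Module):
    """Hypothetical helper: runs split -> cat eagerly, outside FX tracing."""

    def __init__(self, split_size: int, split_dim: int, cat_dim: int):
        super().__init__()
        self.split_size = split_size
        self.split_dim = split_dim
        self.cat_dim = cat_dim

    def forward(self, x):
        chunks = torch.split(x, self.split_size, self.split_dim)
        return torch.cat(chunks, self.cat_dim)


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.split_cat = SplitCat(split_size=4, split_dim=2, cat_dim=0)

    def forward(self, x):
        return self.split_cat(x)


class LeafTracer(torch.fx.Tracer):
    """Hypothetical tracer that records SplitCat as a single call_module node."""

    def is_leaf_module(self, m, module_qualified_name):
        if isinstance(m, SplitCat):
            return True
        return super().is_leaf_module(m, module_qualified_name)


m = M()
graph = LeafTracer().trace(m)
traced = torch.fx.GraphModule(m, graph)
print(traced.graph)

x = torch.randn(2, 3, 8)
assert torch.equal(traced(x), m(x))
```

The trade-off is that the leaf shows up as one opaque `call_module` node, so graph transformations cannot see or rewrite the split and cat inside it.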
Could the following issue be related to the behaviour I am observing? torch.cat does not call __torch_function__ properly · Issue #34294 · pytorch/pytorch · GitHub