Hi,
I am tracing a model that calls torch.split and passes its output directly to torch.cat:
import torch
from torch.fx import symbolic_trace

class M(torch.nn.Module):
    def forward(self, x):
        # Split x along dim 2 into chunks of size 4
        y = torch.split(x, 4, 2)
        # Concatenate the chunks along dim 0
        return torch.cat(y, 0)

m = M()
traced = symbolic_trace(m)
print(traced.graph)
which, when executed, triggers an error:
TypeError: cat() received an invalid combination of arguments - got (Proxy, int), but expected one of:
* (tuple of Tensors tensors, int dim, *, Tensor out)
* (tuple of Tensors tensors, name dim, *, Tensor out)
I have never seen this problem with other torch functions. The difference I see here is that the intermediate value between split and cat is a sequence of tensors rather than a single tensor, and apparently FX does not handle that well for some reason.
I am wondering whether this is the intended behaviour or whether I am missing something. Is there a workaround?
Could the following issue be related to the behaviour I am observing?
torch.cat does not call __torch_function__ properly (Issue #34294, pytorch/pytorch on GitHub)
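For reference, one possible stopgap is to hide the split/cat pair from the tracer by moving it into its own submodule and marking that submodule as a leaf, so FX records a single call_module node instead of tracing into it. This is only a minimal sketch and not necessarily the intended approach: SplitCat and LeafTracer are hypothetical names made up for illustration, and it assumes that overriding Tracer.is_leaf_module is acceptable for the use case.

```python
import torch
from torch.fx import GraphModule, Tracer


class SplitCat(torch.nn.Module):
    """Hypothetical helper module holding the split -> cat pair."""

    def forward(self, x):
        return torch.cat(torch.split(x, 4, 2), 0)


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.split_cat = SplitCat()

    def forward(self, x):
        return self.split_cat(x)


class LeafTracer(Tracer):
    """Hypothetical tracer that does not trace into SplitCat."""

    def is_leaf_module(self, m, module_qualified_name):
        # Treat SplitCat as opaque; defer to the default rule otherwise.
        return isinstance(m, SplitCat) or super().is_leaf_module(
            m, module_qualified_name
        )


m = M()
graph = LeafTracer().trace(m)
traced = GraphModule(m, graph)
print(traced.graph)

# Sanity check: the traced module should match eager execution.
x = torch.randn(2, 3, 8)
assert torch.equal(traced(x), m(x))
```

The drawback is that split and cat no longer appear as separate nodes in the graph, so any transformation that needs to see them individually would not work with this approach.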