FX: tracing fails when the output of split is passed to cat

Hi,

I am tracing a model that calls split and then passes its output directly to cat:

import torch
from torch.fx import symbolic_trace

class M(torch.nn.Module):
    def forward(self, x):
        y = torch.split(x, 4, 2)
        return torch.cat(y, 0)

m = M()
traced = symbolic_trace(m)
print(traced.graph)

When run, this raises the following error:

TypeError: cat() received an invalid combination of arguments - got (Proxy, int), but expected one of:
 * (tuple of Tensors tensors, int dim, *, Tensor out)
 * (tuple of Tensors tensors, name dim, *, Tensor out)

I have never seen this problem with other torch functions. The difference here seems to be that the intermediate value between split and cat is a sequence of tensors rather than a single tensor, and this apparently does not play well with FX: the error shows that cat received a single Proxy where it expected a tuple of tensors.
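A minimal sketch to check this, using only the public `symbolic_trace` API (the `SplitOnly` module and the `seen` dict are hypothetical names for illustration): during tracing, `torch.split` returns one `Proxy` standing for the whole result, not a tuple of proxies, which is consistent with the `(Proxy, int)` in the error message.

```python
import torch
import torch.fx
from torch.fx import symbolic_trace

# Hypothetical dict, used only to record what forward() sees while being traced.
seen = {}

class SplitOnly(torch.nn.Module):
    def forward(self, x):
        y = torch.split(x, 4, 2)
        # Plain Python checks: they run while symbolic_trace calls forward()
        # with a Proxy input, and are not recorded in the graph.
        seen["is_proxy"] = isinstance(y, torch.fx.Proxy)
        seen["is_tuple"] = isinstance(y, tuple)
        return y

gm = symbolic_trace(SplitOnly())
# Values captured during tracing: a single Proxy, not a tuple.
print(seen)
print(gm.graph)

# Running the traced module on a real tensor still returns a tuple of tensors.
out = gm(torch.randn(2, 3, 8))
print(type(out), len(out))
```

Because `y` is a single `Proxy`, passing it to `torch.cat` hands cat one non-sequence object, which is what its argument parser rejects.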

Is this the intended behaviour, or am I missing something? Is there a workaround?
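One possible workaround, sketched here as an assumption rather than an official fix: move the split/cat pair into a module-level helper and register it with `torch.fx.wrap`, so that FX records it as a single opaque `call_function` node instead of tracing into `torch.cat` with a Proxy argument. The helper name `split_then_cat` is hypothetical, and `torch.fx.wrap` has to be called at module top level.

```python
import torch
import torch.fx
from torch.fx import symbolic_trace

# Hypothetical helper: keeps split and cat together so that FX never passes
# a Proxy standing for a tuple into torch.cat.
def split_then_cat(x, split_size, split_dim, cat_dim):
    return torch.cat(torch.split(x, split_size, split_dim), cat_dim)

# Register the helper as a leaf function for FX tracing.
# torch.fx.wrap must be called at the top level of a module, not inside a function.
torch.fx.wrap("split_then_cat")

class M(torch.nn.Module):
    def forward(self, x):
        return split_then_cat(x, 4, 2, 0)

m = M()
traced = symbolic_trace(m)
# The graph should contain one call_function node targeting split_then_cat.
print(traced.graph)

x = torch.randn(2, 3, 8)
# The traced module should compute the same result as eager execution.
print(torch.equal(traced(x), m(x)))
```

The trade-off is that the wrapped helper is a black box to FX: graph passes see one call to `split_then_cat`, not the individual split and cat operations inside it.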

Could this be related to the following issue? torch.cat does not call __torch_function__ properly · Issue #34294 · pytorch/pytorch · GitHub
