I ran your code on Colab and got the following error:
cat() received an invalid combination of arguments - got (Tensor, Tensor, dim=int), but expected one of:
- (tuple of Tensors tensors, int dim, *, Tensor out)
- (tuple of Tensors tensors, name dim, *, Tensor out)
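The `got (Tensor, Tensor, dim=int)` part suggests that the failing line passes the tensors to `torch.cat` as separate positional arguments. `torch.cat` expects a single sequence of tensors as its first argument. This is an assumption, since the offending line isn't shown here. The minimal sketch below reproduces the error and shows the likely fix. The tensors `a` and `b` are hypothetical stand-ins for whatever the original code concatenates.

```python
import torch

# Hypothetical example tensors: shapes match on every dim except dim 0.
a = torch.ones(2, 3)
b = torch.zeros(1, 3)

# Passing the tensors as separate arguments reproduces the reported error,
# because torch.cat's first parameter must be a tuple/list of tensors.
try:
    torch.cat(a, b, dim=0)
except TypeError as err:
    print("TypeError:", str(err).splitlines()[0])

# Wrapping the tensors in a tuple (a list also works) matches the expected signature.
c = torch.cat((a, b), dim=0)
print(c.shape)
```

If the original code has a line of the form `torch.cat(x, y, dim=...)`, changing it to `torch.cat((x, y), dim=...)` should resolve this particular error.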