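Torch tensors share some behavior with Python lists: they can be iterated and unpacked along their first dimension.
*tensor unpacks the first dimension of the tensor into separate tensors (in your case, a single tensor with shape [4,5,6]).
Your code raises an error when the first dimension is greater than 1, because unpacking then produces more than one tensor.
You can fix it by wrapping the unpacked tensors in a list, as follows: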
import torch

x = torch.ones((2,4,5,6))
m = torch.nn.Identity()
# [*x] collects the unpacked slices into one list, so m receives a single argument
print(x.shape, m(x).shape, [out.shape for out in m([*x])])
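To make the difference concrete, here is a minimal sketch comparing the two call styles. It assumes the original code called the module as m(*x), which is a guess based on the fix above. The names parts, raised, outs and shapes are only illustrative.

```python
import torch

x = torch.ones((2, 4, 5, 6))
m = torch.nn.Identity()

# Star-unpacking iterates over dim 0 and yields one tensor per slice.
parts = [*x]
print(len(parts), parts[0].shape)

# m(*x) is the same as m(x[0], x[1]). Identity.forward accepts exactly one
# input, so two positional arguments are expected to raise a TypeError.
try:
    m(*x)
    raised = False
except TypeError as err:
    raised = True
    print("TypeError:", err)

# Passing the slices as one list works: Identity returns its input unchanged.
outs = m([*x])
shapes = [tuple(o.shape) for o in outs]
print(shapes)
```

With a first dimension of 1, m(*x) happens to work because only one tensor is unpacked, which would explain why the error only appears for larger first dimensions.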