How to concatenate a 2D tensor with a 3D tensor without changing their content

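Tensor.expand() is more memory-efficient than Tensor.repeat(). expand() returns a view of the original data, giving the expanded dimension a stride of 0, so nothing is copied. repeat() allocates a new tensor and copies the data into it. Note that expand() can only enlarge dimensions of size 1, so the 2D tensor usually needs an unsqueeze() first.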
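A minimal sketch of the expand-then-concatenate approach. It assumes the 3D tensor has shape (B, N, D1) and the 2D tensor has shape (B, D2), and that the goal is to attach each batch's 2D row to all N positions along the last dimension. The helper name concat_2d_to_3d and all shapes are hypothetical, chosen for illustration. The second half compares expand() and repeat() to show which one shares storage.

```python
import torch


def concat_2d_to_3d(x3: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: append x2 (B, D2) to every position of x3 (B, N, D1).

    Returns a tensor of shape (B, N, D1 + D2); the values of x3 and x2 are unchanged.
    """
    n = x3.size(1)
    # (B, D2) -> (B, 1, D2) -> (B, N, D2); expand() makes a view, no data is copied
    x2_expanded = x2.unsqueeze(1).expand(-1, n, -1)
    # torch.cat allocates the output once and copies both inputs into it
    return torch.cat([x3, x2_expanded], dim=2)


x3 = torch.randn(4, 5, 3)
x2 = torch.randn(4, 2)
out = concat_2d_to_3d(x3, x2)
print(out.shape)  # torch.Size([4, 5, 5])

# expand() vs repeat(): same values, different memory behaviour
expanded = x2.unsqueeze(1).expand(-1, 5, -1)
repeated = x2.unsqueeze(1).repeat(1, 5, 1)
print(expanded.data_ptr() == x2.data_ptr())  # True: view over x2's storage
print(expanded.stride()[1])                  # 0: the expanded dim reuses the same row
print(repeated.data_ptr() == x2.data_ptr())  # False: repeat() allocated new memory
print(torch.equal(expanded, repeated))       # True
```

Because torch.cat copies its inputs into a newly allocated output anyway, expand() saves only the intermediate (B, N, D2) tensor that repeat() would materialize; the concatenated result is identical either way.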