I know how to extract the rows of several tensors and group them into a list of new tensors. However, I haven’t found a good way to do the same for columns, so what I’ve done so far is transpose each tensor first.
import torch
a = torch.randn(3, 5)
b = torch.randn(3, 5)
c = torch.randn(3, 5)
orig = [a, b, c]
list_rows = list(zip(*orig))
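The row version works because iterating over a 2-D tensor yields its rows, so `zip(*orig)` pairs row i of `a`, `b` and `c` into one tuple. A minimal sketch to inspect the result, assuming the same three 3x5 tensors:

```python
import torch

a = torch.randn(3, 5)
b = torch.randn(3, 5)
c = torch.randn(3, 5)
orig = [a, b, c]

# Iterating a tensor walks along dim 0, so each tensor contributes its rows,
# and zip groups the i-th row of every tensor into one tuple.
list_rows = list(zip(*orig))

print(len(list_rows))         # 3: one tuple per row index
print(len(list_rows[0]))      # 3: one row from each of a, b, c
print(list_rows[0][1].shape)  # torch.Size([5])
```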
For the columns, this is what I did:
list(zip(*[torch.transpose(m, 0, 1) for m in orig]))
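A quick sanity check of the transpose approach, assuming three 3x5 tensors (the fixed seed is only there for reproducibility): the result should contain one tuple per column index, and entry `[j][i]` should be column j of `orig[i]`.

```python
import torch

torch.manual_seed(0)
a, b, c = (torch.randn(3, 5) for _ in range(3))
orig = [a, b, c]

# Transposing turns each 3x5 tensor into a 5x3 tensor whose rows are the
# original columns, so the same zip trick now groups columns.
list_cols = list(zip(*[torch.transpose(m, 0, 1) for m in orig]))

print(len(list_cols))  # 5: one tuple per column index
# list_cols[j][i] should be column j of orig[i].
print(all(torch.equal(list_cols[j][i], orig[i][:, j])
          for j in range(5) for i in range(3)))  # True
```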
Is there a cleaner and more straightforward way to handle the column case?
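Two possible alternatives, sketched under the assumption that every tensor in `orig` is 2-D with the same shape. `Tensor.unbind(dim=1)` splits a tensor directly into a tuple of its columns, so no explicit transpose is needed before zipping. Alternatively, `torch.stack` followed by `unbind(dim=2)` returns tensors instead of tuples, which may be more convenient if the grouped columns are going to be used as tensors anyway.

```python
import torch

torch.manual_seed(0)
orig = [torch.randn(3, 5) for _ in range(3)]

# Option 1: unbind(dim=1) yields the columns of each tensor; zip then
# groups column j of every tensor into one tuple.
list_cols = list(zip(*(m.unbind(dim=1) for m in orig)))

# Option 2: stack into one (3, 3, 5) tensor and split along the last dim;
# cols[j] is a 3x3 tensor whose row i is column j of orig[i].
cols = torch.stack(orig).unbind(dim=2)

print(len(list_cols), len(cols))  # 5 5
print(cols[0].shape)              # torch.Size([3, 3])
```

Both options return views of the original data rather than copies, like the transpose-based version.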