What is the equivalent of tf.unstack()?

unstack() in tensorflow is the inverse of stack() (the latter of which is also available in pytorch). I would expect pytorch to have a function that takes, for instance, an MxN tensor and returns a list of M tensors of length N (the rows) or N tensors of length M (the columns), but I couldn't find it.
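For reference, the behaviour described above can be written by hand with plain indexing. This is a minimal sketch. The 3x4 example tensor and the variable names are only for illustration.

```python
import torch

# Illustrative MxN tensor with M=3 rows and N=4 columns
x = torch.arange(12).reshape(3, 4)

# Rows: M tensors, each of shape (N,)
rows = [x[i] for i in range(x.shape[0])]

# Columns: N tensors, each of shape (M,)
cols = [x[:, j] for j in range(x.shape[1])]

print(len(rows), rows[0].shape)  # 3 torch.Size([4])
print(len(cols), cols[0].shape)  # 4 torch.Size([3])
```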

TF documentation of this function is here:

Is there a similar function in pytorch?


You can use torch.unbind: http://pytorch.org/docs/master/torch.html?highlight=unbind#torch.unbind
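A short sketch of how torch.unbind maps onto the rows/columns case from the question. The example tensor is made up for illustration. The dim argument picks the dimension that gets removed, and torch.stack along the same dim puts the pieces back together.

```python
import torch

x = torch.arange(12).reshape(3, 4)  # M=3, N=4

# dim=0 (the default) splits into rows: 3 tensors of shape (4,)
rows = torch.unbind(x, dim=0)

# dim=1 splits into columns: 4 tensors of shape (3,)
cols = torch.unbind(x, dim=1)

# Stacking along the same dim undoes the unbind
assert torch.equal(torch.stack(rows, dim=0), x)
assert torch.equal(torch.stack(cols, dim=1), x)

print(cols[0])  # tensor([0, 4, 8])
```

Note that torch.unbind returns a tuple rather than a list. Wrap it in list(...) if you need a mutable list. The method form x.unbind(1) works as well.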