If we read the documentation, the function's input type is a list or a NumPy array. Hence, you have two choices:

1 - If you read the dataset as a NumPy array and then convert it to a torch.Tensor, split it first, then convert it.
2 - If you read the dataset as a tensor (I don't know how to read it as a tensor anyway), convert it to a NumPy array first (zero_cond.numpy()), then pass it to the function. Of course, you should convert the resulting splits back to tensors afterward.
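The answer doesn't name the function. The sketch below assumes it is scikit-learn's `train_test_split`, which accepts lists and NumPy arrays (among other types). The arrays `features_np` and `labels_np` and their shapes are made-up example data, and `zero_cond` here is just a random stand-in for the tensor from the question.

```python
import numpy as np
import torch
from sklearn.model_selection import train_test_split

# Hypothetical example data: 10 samples with 3 features each, plus labels.
features_np = np.arange(30, dtype=np.float32).reshape(10, 3)
labels_np = np.arange(10)

# Option 1: the data is a NumPy array -> split first, then convert to tensors.
X_train_np, X_test_np, y_train_np, y_test_np = train_test_split(
    features_np, labels_np, test_size=0.2, random_state=0
)
X_train = torch.from_numpy(X_train_np)
X_test = torch.from_numpy(X_test_np)
y_train = torch.from_numpy(y_train_np)
y_test = torch.from_numpy(y_test_np)

# Option 2: the data is already a tensor -> convert to NumPy, split,
# then convert each split back to a tensor.
zero_cond = torch.rand(10, 3)  # hypothetical stand-in for the tensor in the question
cond_train_np, cond_test_np = train_test_split(
    zero_cond.numpy(), test_size=0.2, random_state=0
)
cond_train = torch.from_numpy(cond_train_np)
cond_test = torch.from_numpy(cond_test_np)
```

If the tensor lives on the GPU or requires gradients, `.numpy()` raises an error; `zero_cond.detach().cpu().numpy()` handles both cases.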