Scikit-learn data split issue with torch tensors

I am trying to split my data using scikit-learn's "train_test_split", but I get this error:

“TypeError: take(): argument ‘index’ (position 1) must be Tensor, not numpy.ndarray”

Here is the code:
from sklearn.model_selection import train_test_split
train_0, X_test, y_train, y_test = train_test_split(zero_cond, data_labels, test_size=0.5)

====
Types of the data:
print(zero_cond.type())
torch.FloatTensor
print(data_labels.type())
torch.LongTensor

According to the documentation, the inputs of the function should be lists, NumPy arrays, SciPy sparse matrices, or pandas DataFrames. Hence, you have two options:

1 - If you load the dataset as a NumPy array and then convert it to a torch.Tensor, split it first, then convert it.
2 - If you load the dataset as a tensor (I don't know how to load it directly as a tensor, anyway), convert it to a NumPy array first (zero_cond.numpy()), then pass it to the function. Of course, you should convert the results back to tensors afterwards.
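A minimal sketch of both options, assuming the data fits in CPU memory. The helper names split_then_convert and split_tensors, and the random stand-in data in the usage lines, are hypothetical and not from the original posts. The .detach().cpu() calls are an added precaution: .numpy() raises an error for GPU tensors and for tensors that require grad.

```python
import numpy as np
import torch
from sklearn.model_selection import train_test_split


def split_then_convert(features_np, labels_np, test_size=0.5, seed=None):
    """Option 1 (hypothetical helper): split NumPy arrays, then convert to tensors."""
    parts = train_test_split(features_np, labels_np, test_size=test_size, random_state=seed)
    # torch.from_numpy keeps the dtype (float32 -> FloatTensor, int64 -> LongTensor)
    return tuple(torch.from_numpy(p) for p in parts)


def split_tensors(features, labels, test_size=0.5, seed=None):
    """Option 2 (hypothetical helper): tensor -> NumPy -> split -> tensor."""
    # .detach().cpu() lets .numpy() work for GPU tensors or tensors requiring grad;
    # the returned tensors are therefore on the CPU.
    f_np = features.detach().cpu().numpy()
    l_np = labels.detach().cpu().numpy()
    return split_then_convert(f_np, l_np, test_size=test_size, seed=seed)


# Usage with random stand-ins for the poster's zero_cond / data_labels
zero_cond = torch.rand(10, 3)               # torch.FloatTensor
data_labels = torch.randint(0, 2, (10,))    # torch.LongTensor
X_train, X_test, y_train, y_test = split_tensors(zero_cond, data_labels, test_size=0.5, seed=0)
print(X_train.type(), y_train.type())  # torch.FloatTensor torch.LongTensor
print(X_train.shape, X_test.shape)     # torch.Size([5, 3]) torch.Size([5, 3])
```

Because train_test_split applies the same shuffled indices to every array it receives, each feature row stays paired with its label after the round trip.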


It seems this problem no longer appears in newer versions of PyTorch (>= 1.8.0).