KeyError with a custom PyTorch Dataset class

Hi all experts,

I’m new to PyTorch and I’m getting the KeyError described below. Please help me solve it.

I have created a custom dataset class that gives separate access to the categorical columns (for embeddings) and the numerical columns, but I get a KeyError when I try to iterate through the elements of the data. Please see my code below:

import numpy as np
from torch.utils.data import Dataset, DataLoader

class LoanDataset(Dataset):
    def __init__(self, X, Y, emb_cols):
        X = X.copy()
        self.X1 = X.loc[:, emb_cols].copy().values.astype(np.int64)  # categorical columns
        self.X2 = X.drop(columns=emb_cols).copy().values.astype(np.float32)  # numerical columns
        self.y = Y

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.X1[idx], self.X2[idx], self.y[idx]

Next, I apply this to my x_train and x_test data like so:

train_ds = LoanDataset(x_train, y_train, emb_cols)

test_ds = LoanDataset(x_test, y_test, emb_cols)

batch_size = 64
traindl = DataLoader(x_train, batch_size=batch_size, shuffle=True)
testdl = DataLoader(x_test, batch_size=batch_size, shuffle=True)

When I try to view the elements of train_ds with the code below, I get KeyError: 1

i = 1
for X1, X2, y in train_ds:
    print('batch_num:', i)
    i = i + 1

So I tried to view train_ds[1] and got KeyError: 1 as well. Please help; I would appreciate your support. Thanks!
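Some context on why the loop stops at index 1: LoanDataset defines __getitem__ but no __iter__ (and torch.utils.data.Dataset does not add one), so a plain for loop falls back to Python's sequence protocol. It calls __getitem__(0), __getitem__(1), and so on, and stops cleanly only when an IndexError is raised; any other exception, such as KeyError, propagates out of the loop. The sketch below shows both behaviours in plain Python, using two hypothetical classes (SquaresSeq, LabelLookupSeq) in place of the real dataset.

```python
class SquaresSeq:
    """Hypothetical stand-in for a map-style Dataset: __getitem__ only, no __iter__."""

    def __init__(self, n):
        self.n = n
        self.calls = []  # every index the for loop requests, in order

    def __getitem__(self, idx):
        self.calls.append(idx)
        if idx >= self.n:
            raise IndexError(idx)  # the only exception that ends the loop cleanly
        return idx * idx


class LabelLookupSeq:
    """Hypothetical stand-in whose lookup fails with KeyError at index 1."""

    def __getitem__(self, idx):
        if idx == 1:
            raise KeyError(idx)  # not an IndexError, so the loop does not stop quietly
        return idx


seq = SquaresSeq(3)
values = [v for v in seq]
print(values)     # [0, 1, 4]
print(seq.calls)  # [0, 1, 2, 3]

failed_key = None
try:
    for v in LabelLookupSeq():
        pass
except KeyError as exc:
    failed_key = exc.args[0]
print("KeyError:", failed_key)  # KeyError: 1
```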

It’s a little difficult to pin down the exact problem without knowing the format of x_train, y_train, etc. when the dataset is created. It could be that Y is a wrapped version of the data. You can check this with len(X1), len(X2), and len(Y); if one of them has length 1, the data needs to be unwrapped.
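A minimal illustration of the "wrapped" case, assuming the targets ended up as a NumPy array with an extra outer dimension (this is only one way it can happen, and the values are made up): len() then reports 1 instead of the number of samples, and flattening the array restores the expected length.

```python
import numpy as np

# Hypothetical "wrapped" targets: all 5 labels sit inside one outer row.
y_wrapped = np.array([[0, 1, 0, 1, 1]])
print(len(y_wrapped), y_wrapped.shape)  # 1 (1, 5)

# Unwrap by dropping the extra outer dimension so len() counts samples again.
y_flat = y_wrapped.reshape(-1)
print(len(y_flat), y_flat.shape)  # 5 (5,)
```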

Hi eqy,
len(X1) = len(X2) = len(y) = 34056 samples (rows).

x_train and y_train are DataFrames, created with sklearn’s train_test_split function.

Even if the length (e.g., of Y) is correct, it might not support indexing directly. Can you check the types of X1, X2, and y to see whether something like y.values is needed?
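One way to run that check is to try the same index on each stored attribute and print the type and length alongside the result. The helper check_attrs below is hypothetical, and the SimpleNamespace object only mimics the attributes of a LoanDataset (its y is a Series with made-up, non-contiguous labels); with the real data you would pass train_ds instead.

```python
from types import SimpleNamespace

import numpy as np
import pandas as pd


def check_attrs(obj, names, idx=1):
    """Hypothetical helper: report type, length, and whether attr[idx] succeeds."""
    report = {}
    for name in names:
        value = getattr(obj, name)
        try:
            value[idx]
            ok = True
        except (KeyError, IndexError):
            ok = False
        print(f"{name}: {type(value).__name__}, len={len(value)}, [{idx}] ok={ok}")
        report[name] = ok
    return report


# Stand-in for a LoanDataset instance; y has no row labelled 1.
ds = SimpleNamespace(
    X1=np.array([[0], [1], [2]], dtype=np.int64),
    X2=np.array([[1.5], [2.5], [3.5]], dtype=np.float32),
    y=pd.Series([0, 1, 1], index=[3, 0, 4]),
)
report = check_attrs(ds, ["X1", "X2", "y"])
```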

OK, I found that y is a pandas Series, while X1 and X2 are NumPy ndarrays.
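That fits the KeyError: 1. When train_test_split is given pandas objects it returns pandas objects that keep their original row labels in shuffled order, and for a Series with an integer index, y[idx] looks up the label idx rather than the position idx. The sketch below mimics a split with .iloc on made-up values instead of calling scikit-learn, and compares label lookup with positional access.

```python
import pandas as pd

# Hypothetical full target column with the default 0..4 index.
y_all = pd.Series([10, 11, 12, 13, 14])

# Mimic a shuffled split: keep rows 3, 0, 4; their original labels travel with them.
y_train = y_all.iloc[[3, 0, 4]]
print(y_train.index.tolist())  # [3, 0, 4]

print(y_train.iloc[1])         # 10 -> position 1 (second row)
print(y_train.to_numpy()[1])   # 10 -> ndarray indexing is positional too

missing_label = None
try:
    y_train[1]                 # integer index: [] looks up the *label* 1
except KeyError as exc:
    missing_label = exc.args[0]
print("KeyError:", missing_label)  # KeyError: 1
```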

Right, so you might need something like the .values.astype(...) conversion used on the other inputs, since indexing a Series with [] does not behave like indexing an ndarray. You can also test indexing each of the member variables (self.X1, self.X2, self.y) directly to see which one(s) cause the issue.
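Putting that suggestion into the class, a minimal sketch: the only change from the original is converting Y with np.asarray(...).astype(...), so all three attributes are NumPy arrays and index by position. The float32 target dtype is an assumption; the thread does not say what the target is or which loss is used (class indices for CrossEntropyLoss would need int64 instead). The data at the bottom is made up, with a non-contiguous index like the one a shuffled split produces.

```python
import numpy as np
import pandas as pd
from torch.utils.data import Dataset


class LoanDataset(Dataset):
    def __init__(self, X, Y, emb_cols):
        X = X.copy()
        self.X1 = X.loc[:, emb_cols].copy().values.astype(np.int64)  # categorical columns
        self.X2 = X.drop(columns=emb_cols).copy().values.astype(np.float32)  # numerical columns
        # Convert the target as well, so self.y[idx] is positional like X1/X2.
        # float32 is an assumption (e.g. for BCEWithLogitsLoss).
        self.y = np.asarray(Y).astype(np.float32)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.X1[idx], self.X2[idx], self.y[idx]


# Made-up data; selecting rows 3, 0, 4 leaves an index without label 1.
X = pd.DataFrame({"grade": [0, 1, 2, 1, 0], "income": [1.0, 2.0, 3.0, 4.0, 5.0]})
Y = pd.Series([0, 1, 0, 1, 1])
rows = [3, 0, 4]
ds = LoanDataset(X.iloc[rows], Y.iloc[rows], emb_cols=["grade"])
x1, x2, y = ds[1]  # second row by position, i.e. original row 0
print(x1, x2, y)
```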

After I changed y from a pandas Series to an ndarray, the for loop runs with no errors. However, I now get an error in my training loop; I’m troubleshooting to see what it is.
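The thread does not say what the training-loop error is, so this is only a guess at one thing worth checking: in the original post the DataLoaders wrap x_train and x_test rather than train_ds and test_ds. The DataLoader then looks up x_train[idx] for each sample, and on a DataFrame an integer inside [] selects a column label, not a row, so this would likely fail. The sketch below uses TensorDataset with random tensors as a stand-in for LoanDataset (sizes, dtypes, and the binary target are made up) and shows the loader yielding (X1, X2, y) batches.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Stand-in for LoanDataset: 10 samples, 2 categorical and 3 numerical columns.
x1 = torch.randint(0, 4, (10, 2))       # int64 category codes
x2 = torch.randn(10, 3)                 # float32 numerical features
y = torch.randint(0, 2, (10,)).float()  # float32 targets (binary is an assumption)
train_ds = TensorDataset(x1, x2, y)

# Pass the Dataset object, not the DataFrame, to the DataLoader.
traindl = DataLoader(train_ds, batch_size=4, shuffle=True)

batch_sizes = []
for batch_num, (xb1, xb2, yb) in enumerate(traindl, start=1):
    print("batch_num:", batch_num, tuple(xb1.shape), tuple(xb2.shape), tuple(yb.shape))
    batch_sizes.append(len(yb))
```

Iterating train_ds directly, as in the loop from the question, yields one sample at a time; the batches come from iterating traindl.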
