Hi all,
As per the title, I am trying to implement cross-validation in PyTorch. I know that skorch lets you use PyTorch models with scikit-learn functions, but I am still curious how this could be done in plain PyTorch code.
I looked at other posts and put together what I think could be a reasonable way to do this, but I would love to hear other people's opinions on whether it looks correct:
# Separate the regression target column from the feature columns.
y = main_df['target']
X = main_df.drop(columns='target')

# Convert both to float32 tensors and pair them row-by-row in one dataset,
# so a single index selects a sample's features and its target together.
X_train_t = torch.tensor(X.values, dtype=torch.float32)
y_train_t = torch.tensor(y.values, dtype=torch.float32)
cv_dataset = TensorDataset(X_train_t, y_train_t)
class NetModel(nn.Module):
    """Small fully connected network producing one output value per sample.

    The network is ``Linear(in_count, 32) -> ReLU -> Linear(32, 1)``.
    """

    def __init__(self, in_count: int) -> None:
        """Build the layer stack.

        Args:
            in_count: Number of input features per sample.
        """
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_count, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the layer stack.

        Args:
            x: Input whose last dimension has size ``in_count``.

        Returns:
            The output of the final linear layer (last dimension of size 1).
        """
        return self.layers(x)
# K-fold cross-validation: train a fresh model on each training split and
# report the RMSE on the matching held-out split.
num_epochs = 1
kfold = KFold(n_splits=5, shuffle=True, random_state=42)
# Ask the splitter for the fold count so the final average cannot drift out of
# sync with ``n_splits`` above.
n_splits = kfold.get_n_splits()
tot_rmse = 0.0

for fold, (train_ids, test_ids) in enumerate(kfold.split(cv_dataset)):
    print(f'FOLD {fold+1}')

    # Both loaders read from the same dataset; the samplers restrict each one
    # to the indices of its own split.
    train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
    test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)
    train_loader = DataLoader(cv_dataset, batch_size=5, sampler=train_subsampler)
    test_loader = DataLoader(cv_dataset, batch_size=5, sampler=test_subsampler)

    # Re-create model and optimizer for every fold so no weights or optimizer
    # state leak from one fold into the next.
    model = NetModel(X_train_t.shape[1])
    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)

    for epoch in range(num_epochs):
        model.train()
        for inputs, targets in train_loader:
            # Targets arrive as shape (batch,); the model outputs (batch, 1).
            # Reshape so MSELoss compares element-wise instead of broadcasting.
            targets = targets.reshape((targets.shape[0], 1))
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

    # Evaluate on the held-out split. RMSE is not linear, so averaging
    # per-batch RMSE values does not give the RMSE of the fold (and a smaller
    # last batch would be over-weighted). Accumulate the squared error over
    # all test samples and take a single square root instead.
    sq_err_sum = 0.0
    n_test = 0
    model.eval()
    with torch.no_grad():
        for inputs, targets in test_loader:
            targets = targets.reshape((targets.shape[0], 1))
            outputs = model(inputs)
            sq_err_sum += torch.sum((outputs - targets) ** 2).item()
            n_test += targets.shape[0]

    # KFold never yields an empty test split, so n_test is at least 1 here.
    fold_rmse = (sq_err_sum / n_test) ** 0.5
    print(f"Fold {fold+1} RMSE: {fold_rmse:.3f} | Count samples: {n_test}")
    tot_rmse += fold_rmse
    print("--------------------------------------------------")

# Cross-validation score: mean of the per-fold RMSE values.
print(f"Avg. RMSE: {round(tot_rmse/n_splits, ndigits=2)}")
Output:
FOLD 1
Sum of fold 1 RMSE: 1973.577 | Count batches: 102
--------------------------------------------------
FOLD 2
Sum of fold 2 RMSE: 2887.992 | Count batches: 102
--------------------------------------------------
FOLD 3
Sum of fold 3 RMSE: 1788.802 | Count batches: 102
--------------------------------------------------
FOLD 4
Sum of fold 4 RMSE: 2143.422 | Count batches: 102
--------------------------------------------------
FOLD 5
Sum of fold 5 RMSE: 2420.987 | Count batches: 102
--------------------------------------------------
Avg. RMSE: 21.99