Cross-validation with PyTorch

Hi all,
as per title I am trying to implement cross-validation in PyTorch; I know that skorch lets you run PyTorch code in Sklearn functions but I am still curious to know how this could be done with in simple PyTorch code.

I looked at other post and put together what I think could be an ok way to do this, but I would love to know other people’s opinion and if this looks correct:

X = main_df.drop(columns=['target'])
y = main_df['target']

X_train_t = torch.tensor(X.values, dtype=torch.float32)
y_train_t = torch.tensor(y.values, dtype=torch.float32)

cv_dataset = TensorDataset(X_train_t, y_train_t)
class NetModel(nn.Module):
    def __init__(self, in_count):
        super(NetModel, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_count, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
            )

    def forward(self, x):
        out = self.layers(x)
        return out
num_epochs = 1
kfold = KFold(n_splits=5, shuffle=True, random_state=42)
tot_rmse = 0.0

for fold, (train_ids, test_ids) in enumerate(kfold.split(cv_dataset)):

    print(f'FOLD {fold+1}')

    train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
    test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)

    train_loader = DataLoader(cv_dataset, batch_size=5, sampler=train_subsampler)
    test_loader = DataLoader(cv_dataset, batch_size=5, sampler=test_subsampler)

    # initialize model
    model = NetModel(X_train_t.shape[1])
    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)

    for epoch in range(num_epochs):
        current_loss = 0.0

        model.train()
        for i, data in enumerate(train_loader, 0):

            inputs, targets = data
            targets = targets.reshape((targets.shape[0], 1))

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

    fold_rmse = 0.0
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(test_loader, 0):

            inputs, targets = data
            targets = targets.reshape((targets.shape[0], 1))

            outputs = model(inputs)

            rmse = root_mean_squared_error(targets, outputs)

            # accumulate rmse of the fold
            fold_rmse += rmse

    print(f"Sum of fold {fold+1} RMSE: {fold_rmse:.3f}  |  Count batches: {i+1}")

    # calculate average of the fold rmse (accumulated rmse / number of batches)
    tot_rmse += fold_rmse/(i+1)
    print("--------------------------------------------------")
print(f"Avg. RMSE: {round(tot_rmse/5, ndigits=2)}")

Output:

FOLD 1
Sum of fold 1 RMSE: 1973.577  |  Count batches: 102
--------------------------------------------------
FOLD 2
Sum of fold 2 RMSE: 2887.992  |  Count batches: 102
--------------------------------------------------
FOLD 3
Sum of fold 3 RMSE: 1788.802  |  Count batches: 102
--------------------------------------------------
FOLD 4
Sum of fold 4 RMSE: 2143.422  |  Count batches: 102
--------------------------------------------------
FOLD 5
Sum of fold 5 RMSE: 2420.987  |  Count batches: 102
--------------------------------------------------
Avg. RMSE: 21.99