PyTorch: training on unbalanced classes while validating on balanced classes isn't working

I’ve been coding a PyTorch CNN for a binary classification task on an unbalanced dataset: the two classes occur in a 30:1 ratio. Within the training loop, on the first epoch and every 5 epochs after that, I compute the loss on a small validation set and display it alongside the loss on the much larger training set, to monitor progress and help tune the hyperparameters.
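The 30:1 ratio can be measured directly from the training labels, and that same negative-to-positive count ratio is what the PyTorch `BCEWithLogitsLoss` documentation suggests using for `pos_weight`. A minimal sketch, assuming the 0/1 labels are available as one tensor; `pos_weight_from_labels` is a hypothetical helper, not a PyTorch function, and the labels are synthetic:

```python
import torch

def pos_weight_from_labels(labels: torch.Tensor) -> torch.Tensor:
    """Return num_negatives / num_positives as a 1-element tensor.

    Hypothetical helper: assumes binary labels stored as 0/1.
    """
    labels = labels.flatten().float()
    num_pos = labels.sum()
    num_neg = labels.numel() - num_pos
    if num_pos == 0:
        raise ValueError("no positive examples; pos_weight is undefined")
    return (num_neg / num_pos).reshape(1)

# Synthetic labels: 300 negatives and 10 positives, i.e. a 30:1 ratio.
labels = torch.cat([torch.zeros(300), torch.ones(10)])
print(pos_weight_from_labels(labels))  # tensor([30.])
```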

At first, to keep things simple, I used only a subset of the training data so that the two classes were balanced. That worked: both the training and validation loss improved moderately before overfitting set in.
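Using a balanced subset of the training data amounts to random undersampling of the majority class. A minimal sketch, assuming the 0/1 labels are available as a tensor aligned with the dataset indices; `balanced_subset` is a hypothetical helper and the data is synthetic:

```python
import torch
from torch.utils.data import Subset, TensorDataset

def balanced_subset(dataset, labels: torch.Tensor, seed: int = 0) -> Subset:
    """Randomly undersample the larger class so both classes have equal counts.

    Hypothetical helper: assumes labels[i] is the 0/1 label of dataset[i].
    """
    labels = labels.flatten()
    pos_idx = torch.nonzero(labels == 1).flatten()
    neg_idx = torch.nonzero(labels == 0).flatten()
    n = min(len(pos_idx), len(neg_idx))
    gen = torch.Generator().manual_seed(seed)
    # Draw n indices from each class without replacement.
    pos_keep = pos_idx[torch.randperm(len(pos_idx), generator=gen)[:n]]
    neg_keep = neg_idx[torch.randperm(len(neg_idx), generator=gen)[:n]]
    return Subset(dataset, torch.cat([pos_keep, neg_keep]).tolist())

# Synthetic data at a 30:1 ratio.
x = torch.randn(310, 3)
y = torch.cat([torch.zeros(300), torch.ones(10)])
subset = balanced_subset(TensorDataset(x, y), y)
print(len(subset))  # 20
```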

To improve training, I then switched to using all of the training data and replaced the original loss function, torch.nn.BCELoss, with torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(30.0)), removing the final torch.sigmoid from the model because this loss applies the sigmoid internally. Since the validation data remained balanced, I computed the validation loss with an unweighted torch.nn.BCEWithLogitsLoss() (equivalent to pos_weight=1.0).
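This switch relies on two documented properties of `BCEWithLogitsLoss`: it takes raw logits and applies the sigmoid itself, and `pos_weight` multiplies only the loss terms of positive examples. Note that `pos_weight` is passed as a tensor, as the question's code does with `torch.tensor(30.)`. A small check on random logits, with arbitrary values chosen only for illustration:

```python
import torch

torch.manual_seed(0)
logits = torch.randn(8)
targets = torch.tensor([1., 0., 0., 1., 0., 0., 0., 1.])
pos_weight = torch.tensor(30.)

# Built-in loss: takes raw logits, so the model must not apply sigmoid itself.
builtin = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)(logits, targets)

# The same quantity written out: only the positive-class terms are scaled.
p = torch.sigmoid(logits)
manual = -(pos_weight * targets * torch.log(p)
           + (1 - targets) * torch.log(1 - p)).mean()

# With no pos_weight (a weight of 1) it matches BCELoss on sigmoid outputs.
unweighted = torch.nn.BCEWithLogitsLoss()(logits, targets)
bce_on_probs = torch.nn.BCELoss()(torch.sigmoid(logits), targets)

print(torch.allclose(builtin, manual, atol=1e-5))           # True
print(torch.allclose(unweighted, bce_on_probs, atol=1e-5))  # True
```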

With this setup, although my training loss decreases right away, my validation loss only goes up.

Here is the relevant code:

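import numpy as np
import torch
from datetime import datetime

def train(epochs, optimizer, model, train_loss_fn, train_loader, test_loss_fn, test_loader):
    for epoch in range(1, epochs + 1):
        train_losses = []

        for images, classes in train_loader:
            # Assumed: original right-hand sides not shown; `device` defined elsewhere
            images = images.to(device)
            classes = classes.to(device).float()  # BCEWithLogitsLoss needs float targets
            outputs = model(images)

            loss = train_loss_fn(torch.flatten(outputs), classes)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        # Report on the first epoch and every 5th epoch after that
        if epoch == 1 or epoch % 5 == 0:
            now = datetime.now().strftime("%H:%M:%S")  # assumed timestamp format
            test_loss = test(model, test_loader, test_loss_fn)
            test_loss = np.round(test_loss, 4)
            train_loss = np.round(np.mean(train_losses), 4)
            # get_lr: helper defined elsewhere (not shown)
            lr = np.format_float_scientific(np.squeeze(get_lr(optimizer)), precision=4)
            print(f"{now}: Epoch {epoch}, Train Loss: {train_loss}, Test Loss: {test_loss}, Learning Rate: {lr}")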

def test(model, test_loader, lf):
    test_losses = []
    with torch.no_grad():
        for images, classes in test_loader:
            # Assumed: same device transfer as in train(); originals not shown
            images = images.to(device)
            classes = classes.to(device).float()
            outputs = model(images)

            loss = lf(torch.flatten(outputs), torch.flatten(classes))
            test_losses.append(loss.item())

    mean_loss = np.mean(test_losses)
    return mean_loss

Class_1_Weighting = torch.tensor(30.)
TRAIN_loss_fn = torch.nn.BCEWithLogitsLoss(reduction='mean', pos_weight=Class_1_Weighting)
TEST_loss_fn = torch.nn.BCEWithLogitsLoss(reduction='mean')

# train(...)
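For comparison, a common way to write the validation step is sketched below. It is a hypothetical variant of test(), not the question's actual code: it switches the model to evaluation mode (which only matters if the network has layers such as dropout or batch normalization), restores the previous mode afterwards, and weights each batch's mean loss by its size so the result equals the loss over the whole validation set even when the last batch is smaller. The model and data are synthetic placeholders.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def evaluate(model, loader, loss_fn, device="cpu"):
    """Sample-weighted mean validation loss, computed in eval mode.

    Hypothetical variant of test(): assumes loss_fn uses reduction='mean'.
    """
    was_training = model.training
    model.eval()  # dropout off, BatchNorm uses running statistics
    total, count = 0.0, 0
    with torch.no_grad():
        for images, classes in loader:
            images = images.to(device)
            classes = classes.to(device).float().flatten()
            outputs = model(images).flatten()
            # Weight each batch's mean loss by its size so a short final
            # batch counts proportionally less than a full one.
            total += loss_fn(outputs, classes).item() * classes.numel()
            count += classes.numel()
    model.train(was_training)  # restore the previous mode
    return total / count

# Synthetic example: a linear "model" on random data, batches of 4, 4 and 2.
torch.manual_seed(0)
x = torch.randn(10, 4)
y = torch.randint(0, 2, (10,))
model = torch.nn.Linear(4, 1)
loader = DataLoader(TensorDataset(x, y), batch_size=4)
print(evaluate(model, loader, torch.nn.BCEWithLogitsLoss()))
```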

Cross-posted from Stack Exchange.