Validation loss hits minimum almost immediately

I’m running into an issue where my model starts overfitting almost immediately. I’m fine-tuning a pretrained ResNet-50 for binary classification:

import torch
import torch.nn as nn
from torchvision import models

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = models.resnet50(pretrained=True)  # on torchvision >= 0.13: weights=models.ResNet50_Weights.IMAGENET1K_V1
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(nn.Linear(num_ftrs, 2))  # swap the 1000-class head for a 2-class head
model = model.to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer_ft = torch.optim.SGD(model.parameters(),
                               lr=0.01,
                               momentum=0.9,
                               weight_decay=0.0001)

And my training loop looks like:

def train_model(model, dataloader, criterion, optimizer, num_epochs=25):

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloader[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history only if in train
                with torch.set_grad_enabled(phase == 'train'):
                    # Get model outputs and calculate loss
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    preds = torch.argmax(outputs, dim=1)
                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels)  # labels.data is legacy; comparing to labels directly is equivalent

            epoch_loss = running_loss / len(dataloader[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloader[phase].dataset)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

        print()

    return model
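
In case it’s relevant, dataloader is a dict holding the 'train' and 'val' DataLoaders. My pipeline is essentially the standard ImageFolder setup, roughly like this (data_dir is my dataset root, and the batch size is just what I happen to use):

import os

import torch
from torchvision import datasets, transforms

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
}

# data_dir has train/ and val/ subfolders, one directory per class
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                  for x in ['train', 'val']}
dataloader = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=64,
                                             shuffle=(x == 'train'), num_workers=4)
              for x in ['train', 'val']}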

I’ve tried adjusting weight decay and lowering the learning rate, and I’ve also tried adding dropout in the FC layer, along these lines (the dropout probability varied between runs; 0.5 is just one value I tried):
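
model.fc = nn.Sequential(nn.Dropout(p=0.5),  # dropout before the new 2-class head
                         nn.Linear(num_ftrs, 2))

But no matter what I change, the logs always end up looking like this: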

Epoch 0/49
----------
train Loss: 0.5754 Acc: 0.7006
val Loss: 0.5206 Acc: 0.7448

Epoch 1/49
----------
train Loss: 0.5089 Acc: 0.7478
val Loss: 0.5259 Acc: 0.7509

Epoch 2/49
----------
train Loss: 0.4815 Acc: 0.7655
val Loss: 0.4894 Acc: 0.7622

Epoch 3/49
----------
train Loss: 0.4608 Acc: 0.7794
val Loss: 0.4889 Acc: 0.7633

Epoch 4/49
----------
train Loss: 0.4433 Acc: 0.7893
val Loss: 0.4955 Acc: 0.7667

Epoch 5/49
----------
train Loss: 0.4258 Acc: 0.8001
val Loss: 0.5082 Acc: 0.7613

Every time, overfitting starts around epoch 3-4, even when I let training run for the full 50 epochs. Has anyone run into this before and could offer any advice?

Also, my training set is about 193,000 images, and my validation set is about 21,000.