Age prediction from face images: outputs clustered in a narrow range

When predicting age from face images, my model's predictions are almost all clustered around the late 20s. I've tried adjusting the learning rate, weight decay, batch size, and data augmentation, changing the loss function (MAE, MSE, Huber), and deepening the classifier head of the model, but none of it has helped. Can you help identify what the issue might be? Should I be trying a wider variety of approaches?

Example:

labels = tensor([14., 42., 48., 27., 11., 15., 49., 28., 45., 24., 85., 45., 19., 32.,
                 48., 42., 18., 63., 21., 46., 48., 10., 50., 43., 52., 35., 10., 43.,
                 23., 50., 72., 69.], device='cuda:0')
predictions = tensor([28.5521, 28.4822, 28.5965, 28.3632, 28.5191, 28.3275, 28.5688, 28.4773,
                      28.4645, 28.2645, 28.1887, 28.4816, 28.2675, 28.0921, 28.2920, 28.5518,
                      28.4913, 28.4839, 28.6043, 28.5412, 28.5061, 28.4903, 28.5181, 28.2907,
                      28.4942, 28.5645, 28.1938, 28.3621, 28.3436, 28.4609, 28.5898, 28.5448],
                     device='cuda:0', grad_fn=<...>)

import gc
import os
import sys
import time
import random

import yaml
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.optim import lr_scheduler

from Project.utils.face_dataset import FaceImageDataset
from Project.models.efficient_net import EfficientNet

import ssl
# Work around SSL certificate errors when downloading pretrained weights
ssl._create_default_https_context = ssl._create_stdlib_context

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

def main():
    gc.collect()
    torch.cuda.empty_cache()

    if len(sys.argv) >= 2:
        params_filename = sys.argv[1]
        print(sys.argv)
    else:
        params_filename = '../config/efficient.yaml'

    with open(params_filename, 'r', encoding="UTF8") as f:
        params = yaml.safe_load(f)

    # if 'random_seed' in params:
    #     seed = params['random_seed']
    #     random.seed(seed)
    #     torch.manual_seed(seed)
    #     if torch.cuda.is_available():
    #         torch.cuda.manual_seed_all(seed)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    print(device)
    torch.backends.cudnn.benchmark = True

    # -------------------------------------------- DataLoader --------------------------------------------

    dataset_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    train_data = FaceImageDataset('../data/Training', transform=dataset_transform)
    val_data = FaceImageDataset('../data/Validation', transform=dataset_transform)

    train_loader = DataLoader(train_data, params['batch_size'], shuffle=True)
    dev_loader = DataLoader(val_data, params['batch_size'], shuffle=False)

    # -------------------------------------------- Training Tools --------------------------------------------

    model = EfficientNet().to(device)
    n = sum(p.numel() for p in model.parameters())
    print(f'The Number of Parameters of {model.__class__.__name__} : {n:,}')

    criterion_gender = nn.CrossEntropyLoss()
    criterion_age = nn.HuberLoss()

    optimizer = torch.optim.Adam(model.parameters(), lr=params['lr'], weight_decay=params['l2_reg_lambda'])
    decay_step = [20000, 32000]
    step_lr_scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=decay_step, gamma=0.1)

    timestamp = str(int(time.time()))
    out_dir = os.path.abspath(os.path.join("../scripts/runs", f'{timestamp}_{model.__class__.__name__}'))
    checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
    summary_dir = os.path.join(out_dir, "summaries")

    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    writer = SummaryWriter(summary_dir)

    # ------------------------------------------- Training -------------------------------------------

    cnt = 0
    start_time = time.time()
    lowest_val_loss = 999
    global_steps = 0
    train_age_errors = []
    print('========================================')
    print("Training...")
    for epoch in range(params['max_epochs']):
        train_loss = 0
        train_correct_cnt = 0
        train_batch_cnt = 0
        model.train()
        for img, gender, age in train_loader:
            img = img.to(device)
            gender = gender.to(device)
            age = age.type(torch.FloatTensor).to(device)

            output_gender, output_age = model(img)
            loss_gender = criterion_gender(output_gender, gender)
            loss_age = criterion_age(output_age, age)
            loss = loss_age + loss_gender

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            step_lr_scheduler.step()

            train_loss += loss
            train_batch_cnt += 1

            _, gender_pred = torch.max(output_gender, 1)
            gender_pred = gender_pred.squeeze()
            train_correct_cnt += int(torch.sum(gender_pred == gender))

            age_errors = torch.abs(output_age.squeeze() - age)
            train_age_errors.append(age_errors.detach())

            batch_total = gender.size(0)
            batch_correct = int(torch.sum(gender_pred == gender))
            batch_gender_acc = batch_correct / batch_total

            writer.add_scalar("Batch/Loss", loss.item(), global_steps)
            writer.add_scalar("Batch/Acc", batch_gender_acc, global_steps)
            writer.add_scalar("LR/Learning_rate", step_lr_scheduler.get_last_lr()[0], global_steps)

            global_steps += 1
            if global_steps % 100 == 0:
                print('Epoch [{}], Step [{}], Loss: {:.4f}'.format(epoch + 1, global_steps, loss.item()))

        if cnt < 3:
            print(age, output_age.squeeze())
            cnt += 1

        train_gender_acc = train_correct_cnt / len(train_data) * 100
        train_mae_age = torch.cat(train_age_errors).mean()
        train_ave_loss = train_loss / train_batch_cnt
        training_time = (time.time() - start_time) / 60
        writer.add_scalar("Train/Loss", train_ave_loss, epoch)
        writer.add_scalar("Train/Gender Acc", train_gender_acc, epoch)
        writer.add_scalar("Epoch/Train MAE Age", train_mae_age.item(), epoch)
        print('========================================')
        print("epoch:", epoch + 1, "/ global_steps:", global_steps)
        print("training dataset average loss: %.3f" % train_ave_loss)
        print("training_time: %.2f minutes" % training_time)
        print("learning rate: %.6f" % step_lr_scheduler.get_last_lr()[0])
        # -------------------------------------- Validation --------------------------------------

        cnt = 0
        val_correct_cnt = 0
        val_loss = 0
        val_batch_cnt = 0
        val_age_errors = []
        print('========================================')
        print('Validation...')
        model.eval()
        with torch.no_grad():
            for img, gender, age in dev_loader:
                img = img.to(device)
                gender = gender.to(device)
                age = age.type(torch.FloatTensor).to(device)

                output_gender, output_age = model(img)
                loss_gender = criterion_gender(output_gender, gender)
                loss_age = criterion_age(output_age, age)
                loss = loss_gender + loss_age

                val_loss += loss.item()
                val_batch_cnt += 1
                _, gender_pred = torch.max(output_gender, 1)
                gender_pred = gender_pred.squeeze()
                val_correct_cnt += int(torch.sum(gender_pred == gender))

                age_errors = torch.abs(output_age.squeeze() - age)
                val_age_errors.append(age_errors)

                if cnt < 5:
                    print(age, output_age.squeeze())
                cnt += 1

        val_gender_acc = val_correct_cnt / len(val_data) * 100
        val_mae_age = torch.cat(val_age_errors).mean()
        val_ave_loss = val_loss / val_batch_cnt
        print("validation dataset gender accuracy: %.2f" % val_gender_acc)
        print("validation dataset age MAE: %.3f" % val_mae_age)
        print('========================================')
        writer.add_scalar("Val/Gender Acc", val_gender_acc, epoch)
        writer.add_scalar("Val/Age Loss", val_ave_loss, epoch)

        if val_ave_loss < lowest_val_loss:
            save_path = checkpoint_dir + '/epoch_' + str(epoch + 1) + '.pth'
            torch.save({'epoch': epoch + 1,
                        'model_state_dict': model.state_dict()},
                       save_path)

            save_path = checkpoint_dir + '/best.pth'
            torch.save({'epoch': epoch + 1,
                        'model_state_dict': model.state_dict()},
                       save_path)
            lowest_val_loss = val_ave_loss

if __name__ == '__main__':
    main()

----------------------- Model -----------------------
from torch import nn
from torchvision import models

class EfficientNet(nn.Module):
    def __init__(self):
        super(EfficientNet, self).__init__()
        # Pretrained EfficientNet-B1 backbone with its classifier removed
        self.base_model = models.efficientnet_b1(pretrained=True)
        num_features = self.base_model.classifier[1].in_features
        self.base_model.classifier = nn.Identity()

        # Two task heads on the shared backbone features
        self.gender_classifier = nn.Linear(num_features, 2)
        self.age_regressor = nn.Linear(num_features, 1)

    def forward(self, x):
        features = self.base_model(x)
        gender_output = self.gender_classifier(features)
        age_output = self.age_regressor(features)

        return gender_output, age_output

Hi Rare!

Note that the labels and predictions you quoted don't appear in any of the code you
posted above, so it's hard to be sure exactly what you are doing here.

I would suggest trying to train with loss_age alone (e.g., loss_age.backward()). The
"units" of CrossEntropyLoss and MSELoss (or HuberLoss) are different, so the
loss_gender term might be swamping the total loss, causing loss_age to be
effectively ignored.
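
A minimal sketch of that experiment, assuming the training loop you posted (only
the backward pass changes):

        # Hypothetical age-only training step: temporarily drop the gender
        # term to check whether the age head can learn on its own.
        optimizer.zero_grad()
        loss_age.backward()
        optimizer.step()
        step_lr_scheduler.step()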

If training with loss_age gives you sensible age predictions, then train with a
weighted combination of the two:

        loss_gender = criterion_gender(output_gender, gender)
        loss_age = criterion_age(output_age, age)
        loss_total = loss_age + gender_weight * loss_gender
        loss_total.backward()

Then adjust gender_weight so that you get good predictions for age and gender.
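
One way to get a starting value (this is just a heuristic I'm assuming, and note
the .squeeze(1) to make the shapes match): set gender_weight so the two terms have
comparable magnitude on an initial batch, then tune from there:

        # Hypothetical heuristic: initialize gender_weight so both loss
        # terms start at a comparable scale, then adjust by validation.
        with torch.no_grad():
            img, gender, age = next(iter(train_loader))
            out_gender, out_age = model(img.to(device))
            l_g = criterion_gender(out_gender, gender.to(device))
            l_a = criterion_age(out_age.squeeze(1), age.float().to(device))
        gender_weight = (l_a / l_g).item()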

If you can’t get good age predictions just training on loss_age, try overfitting a
smallish subsample of your data. If you can’t overfit, even after training “a lot”,
then start looking for bugs.
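
For the overfitting check, something along these lines should do
(torch.utils.data.Subset is standard; the subset size of 256 is just an
illustrative choice):

        # Hypothetical overfitting check: train on a small fixed subset.
        # A healthy model should drive age MAE on these samples toward zero.
        from torch.utils.data import Subset
        small_train = Subset(train_data, list(range(256)))
        small_loader = DataLoader(small_train, batch_size=32, shuffle=True)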

By the way, what is the mean age across your entire training dataset?
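
The reason I ask: a regressor that has collapsed to a constant tends to settle
near the mean of its training labels, since that roughly minimizes an MSE- or
Huber-style loss, and a mean around 28 would fit the predictions you quoted. A
quick check, assuming each dataset item returns (img, gender, age) as in your loop:

        # Hypothetical sanity check: the mean age, and the MAE you would get
        # by always predicting it. Loading every item is slow; reading the
        # label file directly would be faster, this is only illustrative.
        ages = torch.tensor([float(train_data[i][2]) for i in range(len(train_data))])
        mean_age = ages.mean()
        baseline_mae = (ages - mean_age).abs().mean()
        print(mean_age.item(), baseline_mae.item())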

Best.

K. Frank