I have a weighted categorical cross entropy function implemented in tensorflow/Keras
# https://gist.github.com/wassname/ce364fddfc8a025bfab4348cf5de852d
def weighted_categorical_crossentropy(weights):
weights = K.variable(weights)
def loss(y_true, y_pred):
# Scale predictions so that the class probs of each sample sum to 1
y_pred /= K.sum(y_pred, axis=-1, keepdims=True)
# Clip to prevent NaN's and Inf's
y_pred = K.clip(y_pred, K.epsilon(), 1 - K.epsilon())
loss = y_true * K.log(y_pred)# * weights
loss = -K.sum(loss, -1)
return loss
return loss
Which I have translated to PyTorch (though currently just using it as a metric for other reasons)
class WeightedCategoricalCrossentropy(nn.Module):
eps = 1e-10
def __init__(self, weights=None):
self.weights = weights
super(WeightedCategoricalCrossentropy, self).__init__()
# Here y_pred is one-encoded; network has not yet had softmax applied
def forward(self, y_pred, y_true):
y_pred = y_pred.permute(0, 2, 3, 1)
y_pred = F.softmax(y_pred, dim=-1)
y_pred /= torch.sum(y_pred, dim=-1, keepdim=True)
y_pred = torch.clip(y_pred, self.eps, 1 - self.eps)
loss = y_true * torch.log(y_pred) #* self.weights
loss = -torch.sum(loss, dim=-1)
return torch.mean(loss).item()
If I comment out the weighting (as done above) and compare to a torch.nn.CrossEntropyLoss
without weighting weight=None
, I can confirm I am calculating the same value
Epoch (1/25) (18s) |##################################################| 100.0% train - loss: 2.1662 acc: 0.1741 wcce: 2.1662, val - loss: 1.8247 acc: 0.0758 wcce: 1.8247
Though, if I apply my class weights by uncommenting I see that I do not calculate the same value as torch.nn.CrossEntropyLoss
with the same weights applied
Epoch (1/25) (18s) |##################################################| 100.0% train - loss: 2.2266 acc: 0.1693 wcce: 1.2113, val - loss: 1.8311 acc: 0.0736 wcce: 1.0105
What is the difference between how the weighting is applied in the custom implementation versus how it is applied in the built-in method?
For reference, here is my training implementation
import time
import numpy as np
import torch.optim as optim
from util import progress
from hsnet.losses import *
from torch.utils.data import DataLoader
# Elsewhere
def get_loss(loss_name, weights=None):
if loss_name == 'cce':
return nn.CrossEntropyLoss(weight=torch.FloatTensor(weights).cuda())
elif loss_name == 'wcce':
return WeightedCategoricalCrossentropy(weights=torch.FloatTensor(weights).cuda())
else:
raise RuntimeError('Unrecognized loss function!')
def train(model, train_dataset, val_dataset, batch_size, epochs, lrate, loss, weights):
print('> Training...')
torch.autograd.set_detect_anomaly(True)
print(' Preparing.')
opt = optim.Adam(model.parameters(), lr=lrate)
# opt = optim.SGD(model.parameters(), lr=0.000001, momentum=0.9)
# opt = optim.RMSprop(model.parameters(), lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False)
metric_loss = get_loss(loss, weights)
metrics = [('acc', get_loss('acc')), ('wcce', get_loss('wcce', weights))]
val_data = DataLoader(val_dataset, batch_size=batch_size)
train_data = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
print(' Begin training...')
history = {
'train': {name: [] for name in ['loss'] + [m for m, _ in metrics]},
'val': {name: [] for name in ['loss'] + [m for m, _ in metrics]}
}
for e in range(epochs):
start_time = time.perf_counter()
# -------------------------
# Training pass
# -------------------------
model.train()
#model.train(mode=True)
train_batch_history = {name: [] for name in ['loss'] + [m for m, _ in metrics]}
num_train_batch = len(train_data)
for b_ind, b_sample in enumerate(train_data):
# Get the data for this batch and move it to the GPU
x, y, y_cat = b_sample['x'], b_sample['y'], b_sample['y_cat']
x = torch.FloatTensor(x).cuda()
y = torch.LongTensor(y).cuda()
y_cat = torch.FloatTensor(y_cat).cuda()
# Perform the actual training
model.zero_grad()
pred = model(x)
loss_value = metric_loss(pred, y)
loss_value.backward()
opt.step()
# Calculate and save the batch loss and metrics
train_batch_history['loss'].append(loss_value.item())
for i, (name, metric) in enumerate(metrics):
if name == 'wcce':
train_batch_history[name].append(metric(pred, y_cat))
else:
train_batch_history[name].append(metric(pred, y))
# Report training progress
suffix = 'train - loss: ' + '{0:.4f}'.format(np.average(train_batch_history['loss']))
for i, (name, metric) in enumerate(metrics):
suffix += ' ' + name + ': ' + '{0:.4f}'.format(np.average(train_batch_history[name]))
progress(b_ind, num_train_batch,
prefix='Batch (' + str(b_ind + 1) + '/' + str(num_train_batch) + ') ',
suffix=suffix, decimals=1, length=50, fill='#')
# Save the training history loss and metrics
history['train']['loss'].append(sum(train_batch_history['loss']) / num_train_batch)
for (name, metric) in metrics:
history['train'][name].append(sum(train_batch_history[name]) / num_train_batch)
# -------------------------
# Validation pass
# -------------------------
model.eval()
#model.train(mode=False)
val_batch_history = {name: [] for name in ['loss'] + [m for m, _ in metrics]}
num_val_batch = len(val_data)
with torch.no_grad():
for b_ind, b_sample in enumerate(val_data):
# Get the data for this batch and move it to the GPU
x, y, y_cat = b_sample['x'], b_sample['y'], b_sample['y_cat']
x = torch.FloatTensor(x).cuda()
y = torch.LongTensor(y).cuda()
y_cat = torch.FloatTensor(y_cat).cuda()
# Run the validation data through the model
pred = model(x)
loss_value = metric_loss(pred, y)
# Calculate and save the batch loss and metrics
val_batch_history['loss'].append(loss_value.item())
for i, (name, metric) in enumerate(metrics):
if name == 'wcce':
val_batch_history[name].append(metric(pred, y_cat))
else:
val_batch_history[name].append(metric(pred, y))
# Save the training history loss and metrics
history['val']['loss'].append(sum(val_batch_history['loss']) / num_val_batch)
for (name, metric) in metrics:
history['val'][name].append(sum(val_batch_history[name]) / num_val_batch)
# Report the epoch results
suffix = 'train - loss: ' + '{0:.4f}'.format(history['train']['loss'][-1])
for (name, metric) in metrics:
suffix += ' ' + name + ': ' + '{0:.4f}'.format(history['train'][name][-1])
suffix += ', val - loss: ' + '{0:.4f}'.format(history['val']['loss'][-1])
for (name, metric) in metrics:
suffix += ' ' + name + ': ' + '{0:.4f}'.format(history['val'][name][-1])
progress(1, 1, prefix='Epoch (' + str(e + 1) + '/' + str(epochs) + ') ({0:.0f}s)'.format(time.perf_counter() - start_time),
suffix=suffix, decimals=1, length=50, fill='#')
return model, history