Hi, I am wondering how I can use the Pearson correlation as a loss function in PyTorch?
Just code it directly. Assuming a batch of N outputs
x = output
y = target
vx = x - torch.mean(x)
vy = y - torch.mean(y)
cost = torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx ** 2)) * torch.sqrt(torch.sum(vy ** 2)))
or you could probably use
cost = torch.sum(vx * vy) * torch.rsqrt(torch.sum(vx ** 2)) * torch.rsqrt(torch.sum(vy ** 2))
Where the rsqrt() function is just the reciprocal of the square root.
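For example, a self-contained version wrapped as a loss might look something like this (just a sketch of the idea above; the eps guard and the sign flip are extras I've added, since to maximize correlation you'd minimize its negative):

import torch

def pearson_loss(output, target, eps=1e-8):
    # Flatten so the means are taken over the whole batch, as in the snippet above.
    x = output.flatten()
    y = target.flatten()
    vx = x - x.mean()
    vy = y - y.mean()
    corr = torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx ** 2)) * torch.sqrt(torch.sum(vy ** 2)) + eps)
    # Negate so that higher correlation means lower loss.
    return -corr

# quick check
out = torch.randn(16, requires_grad=True)
tgt = torch.randn(16)
loss = pearson_loss(out, tgt)
loss.backward()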
@ajbrock Thanks a lot! I thought I needed to implement something like an autograd Function or a customized loss-function class.
It still seems to have a problem; my implementation is the following:
def train(**kwargs):
    # torch.manual_seed(100)  # 10, 100, 666,
    opt.parse(kwargs)
    vis = Visualizer(opt.env)

    # step1: configure model
    model = getattr(models, opt.model)()
    if opt.load_model_path:
        model.load(opt.load_model_path)
    if opt.use_gpu:
        model.cuda()

    # step2: load data
    train_data = STSDataset(opt.train_data_path)
    val_data = STSDataset(opt.train_data_path)
    train_dataloader = DataLoader(train_data, opt.batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers)
    val_dataloader = DataLoader(val_data, opt.batch_size,
                                shuffle=False,
                                num_workers=opt.num_workers)
    torch.save(train_data.X, opt.train_features_path)
    torch.save(train_data.y, opt.train_targets_path)

    # step3: set criterion and optimizer
    criterion = torch.nn.MSELoss()
    lr = opt.lr
    optimizer = torch.optim.Adam(model.parameters(), lr=lr,
                                 weight_decay=opt.weight_decay)
    # optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    # step4: set meters
    loss_meter = meter.MSEMeter()
    previous_loss = 1e100

    # train
    for epoch in range(opt.max_epoch):
        loss_meter.reset()
        for ii, (data, label) in enumerate(train_dataloader):
            # train model on a batch of data
            input = Variable(data)
            target = Variable(torch.FloatTensor(label.numpy()))
            if opt.use_gpu:
                input = input.cuda()
                target = target.cuda()
            optimizer.zero_grad()
            score = model(input)
            # loss = criterion(score, target)  # use MSE loss function
            vx = score - torch.mean(score)
            vy = target - torch.mean(target)
            loss = torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx ** 2)) * torch.sqrt(torch.sum(vy ** 2)))  # use Pearson correlation
            loss.backward()
            optimizer.step()

            # update meters and visualize
            loss_meter.add(score.data, target.data)
            if ii % opt.print_freq == opt.print_freq - 1:
                vis.plot('loss', loss_meter.value())
                # enter debug mode
                if os.path.exists(opt.debug_file):
                    import ipdb
                    ipdb.set_trace()

        # save model for each epoch
        # model.save()

        # validate and visualize
        val_mse, val_pearsonr = val(model, val_dataloader)
        vis.plot('val_mse', val_mse)
        vis.plot('pearson', val_pearsonr)
        vis.log("epoch:{epoch},lr:{lr},loss:{loss},val_mse:{val_mse},val_pearson:{val_pearson}".format(
            epoch=epoch,
            lr=lr,
            loss=loss_meter.value(),
            val_mse=str(val_mse),
            val_pearson=str(val_pearsonr)))

        # update learning rate
        if loss_meter.value() > previous_loss:
            lr = lr * opt.lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        previous_loss = loss_meter.value()
And when I check the output, the MSE and Pearson correlation are all NaN.
That could be any one of a million things, and there's also no guarantee that Pearson's r is a good loss function to optimize, just FYI. You might want to consider dividing by the batch size (I take sums, but you could take means), looking into exactly what torch.mean is calculating (if your data has trailing dimensions then you need to account for that), what your model is, how it is initialized, what your data is, etc. All the typical ML stuff.
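For instance, if your tensors are (batch, features), a per-sample version might look roughly like this (just a sketch; the exact dims to reduce over depend on your data layout, which I'm guessing at here):

import torch

def batch_pearson(output, target, eps=1e-8):
    # output, target: (N, D) -- one correlation per sample, reduced over dim=1
    vx = output - output.mean(dim=1, keepdim=True)
    vy = target - target.mean(dim=1, keepdim=True)
    corr = (vx * vy).sum(dim=1) / (
        torch.sqrt((vx ** 2).sum(dim=1)) * torch.sqrt((vy ** 2).sum(dim=1)) + eps
    )
    return corr.mean()  # average over the batch instead of summing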
Yeah you are right! Thanks for your suggestions, I shall look into all those potential problems to find out why.
Did you look into the errors and find a solution for the NaNs?
I added an epsilon of 1e-6 to the cost and created a new loss class for training. However, the value of the loss is always 0 when computing the PCC. Please help! My code follows.
class CCCLoss(torch.nn.Module):
    def __init__(self, eps=1e-6):
        super(CCCLoss, self).__init__()
        self.eps = eps

    def forward(self, y_true, y_hat):
        y_true_mean = torch.mean(y_true)
        y_hat_mean = torch.mean(y_hat)
        y_true_var = torch.var(y_true)
        y_hat_var = torch.var(y_hat)
        y_true_std = torch.std(y_true)
        y_hat_std = torch.std(y_hat)
        vx = y_true - y_true_mean
        vy = y_hat - y_hat_mean
        pcc = torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx ** 2) + self.eps) * torch.sqrt(torch.sum(vy ** 2) + self.eps))
        ccc = (2 * pcc * y_true_std * y_hat_std) / \
              (y_true_var + y_hat_var + (y_hat_mean - y_true_mean) ** 2)
        ccc = 1 - ccc
        return ccc
If one of the vectors you are computing the correlation with is constant (all values are equal), then the correlation computation will have a division by zero. Deal with this in advance so that you don't have a torch.sqrt(0) situation.
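Something along these lines, for example (a rough sketch; the clamp value is arbitrary):

import torch

def safe_pearson(x, y, eps=1e-8):
    vx = x - x.mean()
    vy = y - y.mean()
    denom = torch.sqrt(torch.sum(vx ** 2)) * torch.sqrt(torch.sum(vy ** 2))
    # If either vector is (numerically) constant, the denominator collapses
    # to zero; clamp it so the forward/backward pass never divides by zero.
    return torch.sum(vx * vy) / denom.clamp(min=eps)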
I know it's an old topic, but I had the same question and ended up here. I think I have a much simpler and more stable workaround to share.
Knowing that the Pearson correlation is a "centered version" of the cosine similarity, you can simply get it with:
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
pearson = cos(x1 - x1.mean(dim=1, keepdim=True), x2 - x2.mean(dim=1, keepdim=True))
Plus you benefit from the stability of the PyTorch implementation of cosine similarity: the eps term avoids any division by 0, and dim lets you choose the dimension along which the Pearson correlation is computed.
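For example, assuming (batch, features) tensors, a full usage could look like this (the 1 - mean reduction at the end is just one choice of how to turn it into a scalar loss):

import torch
import torch.nn as nn

cos = nn.CosineSimilarity(dim=1, eps=1e-6)

def pearson_from_cosine(x1, x2):
    # Center each row, then cosine similarity equals the per-row Pearson correlation.
    return cos(x1 - x1.mean(dim=1, keepdim=True), x2 - x2.mean(dim=1, keepdim=True))

x1 = torch.randn(8, 32, requires_grad=True)
x2 = torch.randn(8, 32)
loss = 1 - pearson_from_cosine(x1, x2).mean()
loss.backward()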
Hope that helps someone.
Just to clarify: to use the above CosineSimilarity as a loss, we need to multiply it by -1, right?
Note: use torch.norm(x) instead of torch.sqrt(torch.sum(x**2)) to avoid NaN gradients.
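That is, something like this (a sketch of the substitution, relying on the note's claim about the gradient behaviour at zero):

import torch

def pearson_with_norm(x, y):
    vx = x - x.mean()
    vy = y - y.mean()
    # torch.norm(v) computes the same quantity as torch.sqrt(torch.sum(v ** 2)).
    return torch.sum(vx * vy) / (torch.norm(vx) * torch.norm(vy))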
Well, although we have torch.corrcoef() now, I found that it's really slow. I'm training a modified MLP on a 3060, spending minutes on 10 correlations of 4096*4096 matrices, which is substantially slower than the sum of all the other calculations. So I'm going to try other cost functions.
10 * 4096 * 4096 = 167,772,160.
Depending on what exactly is going on there, you might be underestimating how long your calculations should take. Are you trying to crunch at least 160 million floats in every forward pass? That's certainly not a cheap operation. Are you running out of CPU memory?