CRF loss for semantic segmentation

I am doing semantic segmentation and was wondering if there is a method in PyTorch that will allow me to compute the CRF loss shown below? I am not trying to do inference. I just want to compute the loss based on the unary and pairwise terms.

I could do it myself: replicate the output 8 times, shift the pixels accordingly, and compute the differences to determine whether neighbouring labels are similar. But I don’t know if that would be efficient.

I think you want to give more context than that. I assume that phi is your “ground truth loss” (x being the ground truth or some such information) and psi the consistency loss.
Now psi is rather underspecified in the formula. From your description, it looks like i, j iterate over all pairs of adjacent pixels (each pixel would have 8 neighbors if you count axis-parallel and diagonal neighbors in 2D, or 6 if you count axis-parallel neighbors only in 3D).
Quite likely, doing the slicing manually is neither terribly efficient nor terribly inefficient: (y[:, :, 1:, 1:] - y[:, :, :-1, :-1]).abs().sum() and friends.
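To make "and friends" concrete, here is a minimal sketch of the slicing approach. It assumes y is an (N, C, H, W) tensor of per-class scores or probabilities and that psi is a plain absolute difference. The thread does not pin down psi, so both are assumptions, and the function name pairwise_l1 is made up for illustration.

```python
import torch

def pairwise_l1(y: torch.Tensor) -> torch.Tensor:
    """Sum of absolute differences over all 8-connected neighbour pairs.

    y: (N, C, H, W) tensor (assumed layout).
    Each unordered pair of neighbours is counted once, so four offsets
    (right, down, down-right, down-left) cover the 8-neighbourhood.
    """
    right = (y[:, :, :, 1:] - y[:, :, :, :-1]).abs().sum()
    down = (y[:, :, 1:, :] - y[:, :, :-1, :]).abs().sum()
    down_right = (y[:, :, 1:, 1:] - y[:, :, :-1, :-1]).abs().sum()
    down_left = (y[:, :, 1:, :-1] - y[:, :, :-1, 1:]).abs().sum()
    return right + down + down_right + down_left
```

If you really want every ordered pair (i, j), i.e. all 8 directions per pixel, the result is simply twice this value, because the absolute difference is symmetric.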

Best regards

Thomas

Sorry, here is a description of the terms used

But you are right, and that is exactly what I am trying to do. And thank you, doing it your way for each of the 8 neighbours,

(y[:, :, 1:, 1:] - y[:, :, :-1, :-1]).abs().sum()

looks cleaner and more efficient than my initial plan to use torch.roll for each neighbour.
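One more reason to prefer slicing over torch.roll: roll wraps around, so pixels on one border get compared with pixels on the opposite border unless you crop that column or row afterwards. A small sketch for the "right neighbour" offset only (the helper names are made up for illustration, and y is assumed to be (N, C, H, W)):

```python
import torch

def diff_right_slice(y: torch.Tensor) -> torch.Tensor:
    # Compare each pixel with its right neighbour; no wrap-around.
    return (y[:, :, :, 1:] - y[:, :, :, :-1]).abs().sum()

def diff_right_roll(y: torch.Tensor) -> torch.Tensor:
    # torch.roll is circular: the last column is compared with the first.
    shifted = torch.roll(y, shifts=-1, dims=3)
    return (shifted - y).abs().sum()

def diff_right_roll_cropped(y: torch.Tensor) -> torch.Tensor:
    # Dropping the last column removes the wrapped-around pairs.
    shifted = torch.roll(y, shifts=-1, dims=3)
    return (shifted - y)[:, :, :, :-1].abs().sum()
```

With the crop, the roll version matches the slicing version, but it allocates a full shifted copy of y first, whereas the slices are views of the same storage.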