Resizing semantic segmentation labels


I have per-pixel labels of shape [1, H, W] with dtype torch.long (torch.int64), where each integer is a class index for a semantic segmentation task.

When I try to resize them with torch.nn.functional.interpolate, I get an error saying it is not implemented for torch.long.

I tried converting the tensor to uint8, resizing it, and converting it back to torch.long, but somehow this adds NaNs to the mask.

Any ideas? What is the standard way to do this?


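See this if you are resizing during data loading. Otherwise, I usually convert the segmentation mask to float, call torch.nn.functional.interpolate(mask, $TARGET_SHAPE, mode='nearest'), and convert the result back to long. Note that interpolate expects a batch and a channel dimension before the spatial ones, so a [1, H, W] mask needs an unsqueeze(0) first to become [1, 1, H, W]. Nearest mode matters here: every output pixel copies an existing label, whereas a mode like bilinear would blend neighbouring class indices into meaningless values.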
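A minimal sketch of the float route described above. The helper name `resize_labels` is hypothetical, and it assumes a [1, H, W] integer mask and a target `size` of (height, width). float32 represents integers exactly up to 2^24, so class indices survive the round trip.

```python
from typing import Tuple

import torch
import torch.nn.functional as F


def resize_labels(labels: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
    """Resize a [1, H, W] integer label map to [1, *size] without mixing classes.

    Hypothetical helper for illustration.
    """
    # interpolate treats dim 0 as batch and dim 1 as channels, so 2-D resizing
    # needs a 4-D input: [1, H, W] -> [1, 1, H, W]
    x = labels.unsqueeze(0).float()
    # 'nearest' copies an existing value into every output pixel, so no
    # in-between (non-existent) class indices are created
    out = F.interpolate(x, size=size, mode="nearest")
    # drop the added batch dim and restore the integer dtype
    return out.squeeze(0).long()


labels = torch.randint(0, 21, (1, 64, 64))
resized = resize_labels(labels, (128, 96))
print(resized.shape, resized.dtype)
```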

Sweet, that works! I ended up using uint8 instead of float, because interpolate is implemented for that dtype.
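A sketch of the uint8 variant, assuming every class index (including any ignore label such as 255) fits in uint8's 0 to 255 range, and that your PyTorch build supports nearest interpolation on uint8 tensors for your device. The helper name `resize_labels_uint8` is hypothetical. The range check guards against indices silently wrapping around during the cast.

```python
from typing import Tuple

import torch
import torch.nn.functional as F


def resize_labels_uint8(labels: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
    """Resize a [1, H, W] label map via uint8 (hypothetical helper).

    Assumes all class indices lie in [0, 255].
    """
    # uint8 only holds 0..255; larger or negative indices would wrap on cast
    if labels.min() < 0 or labels.max() > 255:
        raise ValueError("class indices must lie in [0, 255] for the uint8 route")
    # add the batch dim: [1, H, W] -> [1, 1, H, W]
    x = labels.to(torch.uint8).unsqueeze(0)
    out = F.interpolate(x, size=size, mode="nearest")
    # back to [1, H', W'] and the int64 dtype expected by loss functions
    return out.squeeze(0).long()


labels = torch.randint(0, 21, (1, 32, 32))
print(resize_labels_uint8(labels, (64, 64)).shape)
```

Compared with the float route, this keeps the intermediate tensor smaller, at the cost of the 256-class limit.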