Convert n4hw:uint8 to n1hw:int for image processing?

I’m working with images, and sometimes it’s useful to view them as one int per pixel, and sometimes as four bytes per pixel.

Is there a way to convert between these two views in PyTorch?

Specifically, I’m looking to go from four bytes per pixel to one int per pixel, ideally while reusing the underlying memory.


Would Tensor.view(dtype) work? It reinterprets the existing bytes as another dtype without copying them:

>>> a = torch.tensor([1, 2, 3, 4]).to(torch.uint8)
>>> a
tensor([1, 2, 3, 4], dtype=torch.uint8)
>>> a.view(torch.int32)
tensor([67305985], dtype=torch.int32)
>>> a = torch.tensor([4, 3, 2, 1]).to(torch.uint8)
>>> a.view(torch.int32)
tensor([16909060], dtype=torch.int32)
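The two results above come from reading each group of four bytes in little-endian order, with the first byte least significant. A minimal sketch that checks this by packing the bytes by hand with shifts; it assumes a little-endian machine (what sys.byteorder reports on x86-64 and typical ARM systems), and the variable names are illustrative only:

```python
import sys
import torch

# Two pixels, four bytes each (e.g. R, G, B, A), stored in the last dimension.
pixels = torch.tensor([[1, 2, 3, 4], [4, 3, 2, 1]], dtype=torch.uint8)

# Reinterpret each group of four bytes as one int32 (no copy); shape: [2, 1].
packed = pixels.view(torch.int32)

# Rebuild the same values by hand, assuming little-endian byte order:
# byte 0 is the least significant, byte 3 the most significant.
b = pixels.to(torch.int64)
manual = b[:, 0] | (b[:, 1] << 8) | (b[:, 2] << 16) | (b[:, 3] << 24)

print(sys.byteorder)        # "little" on x86-64 and most ARM machines
print(packed.squeeze(-1))   # tensor([67305985, 16909060], dtype=torch.int32)
print(manual)               # tensor([67305985, 16909060])
```

If the most significant byte is 128 or more, the int32 value comes out negative because int32 is signed; the stored bytes themselves are unchanged.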

That works excellently, thanks @ptrblck !

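For future readers: the four bytes of each pixel have to sit next to each other in memory, so you need to rearrange your data to channels last (NHWC), make it contiguous, and have four channels. More precisely, the last dimension’s size must be divisible by 4 and its stride must be 1. E.g.: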

import torch

a = torch.ones((1, 4, 5, 5), dtype=torch.uint8).permute(0, 2, 3, 1).contiguous()  # shape: [1, 5, 5, 4], NHWC
b = a.view(torch.int32)  # shape: [1, 5, 5, 1], NHWC, shares memory with a
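Since the question asked about reusing memory, here is a small sketch, using the same setup as above, that checks the int32 view and the uint8 tensor share storage, and that viewing back as uint8 recovers the four-bytes-per-pixel shape. The byte order shown in the second print assumes a little-endian machine. Note that the permute plus .contiguous() step does copy once; only the dtype view is free.

```python
import torch

# NCHW uint8 image with four channels, moved to NHWC and made contiguous.
a = torch.ones((1, 4, 5, 5), dtype=torch.uint8).permute(0, 2, 3, 1).contiguous()
b = a.view(torch.int32)  # shape: [1, 5, 5, 1], no copy

# Both tensors start at the same address: the view reuses a's storage.
print(a.data_ptr() == b.data_ptr())  # True

# Writing through the int32 view changes the bytes seen through a.
b[0, 0, 0, 0] = 0x04030201
print(a[0, 0, 0])  # tensor([1, 2, 3, 4], dtype=torch.uint8) on a little-endian machine

# Viewing back as uint8 restores the four-bytes-per-pixel layout.
c = b.view(torch.uint8)
print(c.shape)  # torch.Size([1, 5, 5, 4])
```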

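If you get the channel dimension wrong or the tensor isn’t contiguous, you get quite sensible error messages. For the example above, viewing the NCHW tensor directly (last dimension W = 5) gives the first error below, and skipping .contiguous() after the permute (last-dimension stride H × W = 25) gives the second: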

  • RuntimeError: self.size(-1) must be divisible by 4 to view Byte as Int (different element sizes), but got 5

  • RuntimeError: self.stride(-1) must be 1 to view Byte as Int (different element sizes), but got 25
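Putting the pieces together, a minimal sketch of a round trip between uint8 [N, 4, H, W] and int32 [N, H, W]. pack_pixels and unpack_pixels are hypothetical helper names, not PyTorch APIs. The shape check up front turns the RuntimeErrors above into an earlier ValueError, and the permute plus .contiguous() step copies the data once while the dtype view itself does not.

```python
import torch

def pack_pixels(nchw: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: uint8 [N, 4, H, W] -> int32 [N, H, W], one int per pixel."""
    if nchw.dtype != torch.uint8 or nchw.dim() != 4 or nchw.size(1) != 4:
        raise ValueError(f"expected uint8 [N, 4, H, W], got {nchw.dtype} {tuple(nchw.shape)}")
    # Channels last, then contiguous so that stride(-1) == 1 (this copies).
    nhwc = nchw.permute(0, 2, 3, 1).contiguous()
    # Reinterpret each 4-byte group as one int32, then drop the size-1 last dim.
    return nhwc.view(torch.int32).squeeze(-1)

def unpack_pixels(nhw: torch.Tensor) -> torch.Tensor:
    """Hypothetical inverse: int32 [N, H, W] -> uint8 [N, 4, H, W]."""
    # [N, H, W, 1] int32 -> [N, H, W, 4] uint8, reusing the same bytes.
    nhwc = nhw.contiguous().unsqueeze(-1).view(torch.uint8)
    # Back to NCHW; this is a strided view, call .contiguous() if a dense copy is needed.
    return nhwc.permute(0, 3, 1, 2)

img = torch.randint(0, 256, (2, 4, 3, 5), dtype=torch.uint8)
packed = pack_pixels(img)
print(packed.shape, packed.dtype)               # torch.Size([2, 3, 5]) torch.int32
print(torch.equal(unpack_pixels(packed), img))  # True
```

The round trip is exact regardless of byte order, because both directions only reinterpret the same bytes; byte order matters only if you interpret the int32 values themselves.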