How can I pad all 4 dimensions of an NCHW tensor?

For example, say I have a tensor of size (3,5,8,8):

x = torch.ones(3,5,8,8)

And I want to pad the dimensions by separate padding values: N by 2, C by 1, H by 3, and W by 3.

How can I accomplish this?

You can use F.pad:

import torch
import torch.nn.functional as F

x = torch.ones(3, 5, 8, 8)
out = F.pad(x, (1, 2, 1, 2, 0, 1, 1, 1))
print(out.shape)
> torch.Size([5, 6, 11, 11])

The docs explain the ordering of the pad values: pairs of (begin, end) padding starting from the last dimension, so the tuple above reads (W: 1 left, 2 right, H: 1 top, 2 bottom, C: 0 front, 1 back, N: 1 front, 1 back).
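Note that F.pad's non-constant modes only cover the trailing spatial dimensions: for a 4D input, mode='reflect' accepts at most a 4-value pad tuple (W and H), so the N and C dimensions cannot be reflection-padded this way. A quick illustration:

```python
import torch
import torch.nn.functional as F

x = torch.ones(3, 5, 8, 8)

# 'reflect' (like 'replicate') only pads the trailing spatial dims,
# here W and H, so the pad tuple has at most 4 values for a 4D input.
out = F.pad(x, (3, 3, 2, 2), mode='reflect')
print(out.shape)  # torch.Size([3, 5, 12, 14])
```

Passing an 8-value tuple with mode='reflect' raises an error, which is why padding N and C reflectively needs a custom function.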

@ptrblck What if I want to use reflection padding? I tried to write a function for it, but it produces incorrectly repeated values, and the first dimension ends up with the wrong size:

import torch
from typing import List

def pad_reflective_a4d(x: torch.Tensor, padding: List[int]) -> torch.Tensor:
    """
    Reflective padding for all 4 dimensions of an NCHW tensor
    """

    assert x.dim() == 4
    assert len(padding) == 8

    # Pad width
    if padding[0] != 0:
        x = torch.cat([x, x.flip([3])[..., 0 : padding[0]]], dim=3)
    if padding[1] != 0:
        x = torch.cat([x.flip([3])[..., -padding[1] :], x], dim=3)

    # Pad height
    if padding[2] != 0:
        x = torch.cat([x, x.flip([2])[..., 0 : padding[2], :]], dim=2)
    if padding[3] != 0:
        x = torch.cat([x.flip([2])[..., -padding[3] :, :], x], dim=2)

    # Pad channels
    if padding[4] != 0:
        x = torch.cat([x, x.flip([1])[:, 0 : padding[4]]], dim=1)
    if padding[5] != 0:
        x = torch.cat([x.flip([1])[:, -padding[5] :], x], dim=1)

    # Pad batch
    if padding[6] != 0:
        x = torch.cat([x, x.flip([0])[0 : padding[6]]], dim=0)
    if padding[7] != 0:
        x = torch.cat([x.flip([0])[-padding[7] :], x], dim=0)
    return x

x = torch.arange(1,(3*4*4)+1).view(1, 3, 4, 4).float()
out = pad_reflective_a4d(x, (2,2, 2,2, 2,2, 2,2))
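The repeated values appear because every slice starts at the edge element, so the function actually implements symmetric padding; true reflection skips the edge. The first dimension comes out wrong because reflection requires each pad to be smaller than the size of the dimension it pads, and a batch of 1 cannot be reflected by 2, so the slices silently come up short. A corrected sketch (hypothetical name pad_reflect_4d, same pad ordering as the function above):

```python
import torch
from typing import List

def pad_reflect_4d(x: torch.Tensor, padding: List[int]) -> torch.Tensor:
    """Reflection padding for all 4 dimensions of an NCHW tensor.

    padding uses the same ordering as the function above:
    (w_back, w_front, h_back, h_front, c_back, c_front, n_back, n_front).
    """
    assert x.dim() == 4
    assert len(padding) == 8
    # Reflection needs pad < dim size, otherwise the slices run out of values.
    dims = [3, 3, 2, 2, 1, 1, 0, 0]
    assert all(p < x.size(d) for p, d in zip(padding, dims))

    # Pad width: start at index 1 (or stop at -1) so the edge is not repeated.
    if padding[0] != 0:
        x = torch.cat([x, x.flip([3])[..., 1 : padding[0] + 1]], dim=3)
    if padding[1] != 0:
        x = torch.cat([x.flip([3])[..., -(padding[1] + 1) : -1], x], dim=3)

    # Pad height
    if padding[2] != 0:
        x = torch.cat([x, x.flip([2])[..., 1 : padding[2] + 1, :]], dim=2)
    if padding[3] != 0:
        x = torch.cat([x.flip([2])[..., -(padding[3] + 1) : -1, :], x], dim=2)

    # Pad channels
    if padding[4] != 0:
        x = torch.cat([x, x.flip([1])[:, 1 : padding[4] + 1]], dim=1)
    if padding[5] != 0:
        x = torch.cat([x.flip([1])[:, -(padding[5] + 1) : -1], x], dim=1)

    # Pad batch
    if padding[6] != 0:
        x = torch.cat([x, x.flip([0])[1 : padding[6] + 1]], dim=0)
    if padding[7] != 0:
        x = torch.cat([x.flip([0])[-(padding[7] + 1) : -1], x], dim=0)
    return x

# A batch of 1 cannot be reflection-padded by 2, so use a larger batch:
x = torch.arange(1., 3 * 3 * 4 * 4 + 1).view(3, 3, 4, 4)
out = pad_reflect_4d(x, [2, 2, 2, 2, 2, 2, 2, 2])
print(out.shape)  # torch.Size([7, 7, 8, 8])
```

For the H and W dimensions this agrees with F.pad(x, (2, 2, 2, 2), mode='reflect'), which is a convenient way to sanity-check the custom slices.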