Permute inside nn.Sequential

I have a convolutional layer defined inside a sequential model and would like to permute its output. Does PyTorch have an equivalent to `x = x.permute(0, 2, 1)` that can be used inside `nn.Sequential`?

I’m not aware of a built-in module, but you can easily create your own Permute layer:

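```python
import torch
import torch.nn as nn


class Permute(nn.Module):
    """Wraps Tensor.permute so it can be used as a layer in nn.Sequential."""

    def __init__(self, dims):
        super().__init__()
        self.dims = dims  # new order of the dimensions, e.g. [0, 2, 1]

    def forward(self, x):
        return x.permute(self.dims)


model = nn.Sequential(
    nn.Linear(10, 20),        # (1, 10, 10) -> (1, 10, 20)
    Permute(dims=[0, 2, 1]),  # (1, 10, 20) -> (1, 20, 10)
    nn.Linear(10, 20),        # (1, 20, 10) -> (1, 20, 20)
)

x = torch.randn(1, 10, 10)
out = model(x)
print(out.shape)
# torch.Size([1, 20, 20])
```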
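Since the question mentions a convolutional layer, here is a minimal sketch of the same idea placed directly after an `nn.Conv1d`. Its output has shape `(N, C, L)`, and permuting it to `(N, L, C)` lets a following `nn.Linear` operate on the channel dimension. The `Permute` class is repeated so the snippet runs on its own. The channel counts, kernel size, and sequence length are arbitrary values chosen for illustration, and the `extra_repr` override is an optional addition that only changes how the layer is printed.

```python
import torch
import torch.nn as nn


class Permute(nn.Module):
    """Same layer as above: permutes the dimensions of its input."""

    def __init__(self, dims):
        super().__init__()
        self.dims = tuple(dims)

    def forward(self, x):
        return x.permute(self.dims)

    def extra_repr(self):
        # Optional: show the permutation when the model is printed.
        return f"dims={self.dims}"


# Illustrative sizes: 3 input channels, 16 conv channels, length 50.
model = nn.Sequential(
    nn.Conv1d(in_channels=3, out_channels=16, kernel_size=3, padding=1),  # (N, 3, 50) -> (N, 16, 50)
    Permute(dims=[0, 2, 1]),                                             # (N, 16, 50) -> (N, 50, 16)
    nn.Linear(16, 8),                                                    # (N, 50, 16) -> (N, 50, 8)
)

x = torch.randn(4, 3, 50)
out = model(x)
print(out.shape)  # torch.Size([4, 50, 8])
print(model[1])   # Permute(dims=(0, 2, 1))
```

With `padding=1` and `kernel_size=3`, the convolution keeps the sequence length at 50, so only the channel dimension changes before the permute.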