Transpose "axes don't match" error


So I'm trying to visualize the output of my spatial transformer network (STN) component. When I apply a transpose of (1, 2, 0) to the STN output, which has shape (1, 50, 22, 22), I get an "axes don't match" error. I'm following this demo: to see how to visualize the STN.

Any help is greatly appreciated.

I suppose you are calling transpose on the NumPy array, not on your tensor, right?
Assuming so, you have to pass all 4 axes to the transpose call, since your array has 4 dimensions.

This works:

np.random.randn(1, 50, 22, 22).transpose(1, 2, 0, 3)

while this throws your error:

np.random.randn(1, 50, 22, 22).transpose(1, 2, 0)
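To see why only the second call fails: NumPy's transpose expects a permutation of every axis, so a 4-D array needs exactly 4 entries. A minimal sketch (the helper name try_transpose is just for illustration) that reports either the resulting shape or the error message:

```python
import numpy as np

def try_transpose(shape, axes):
    """Hypothetical helper: return the transposed shape, or the error text."""
    arr = np.zeros(shape)
    try:
        return arr.transpose(axes).shape
    except ValueError as err:  # raised when axes is not a full permutation
        return str(err)

print(try_transpose((1, 50, 22, 22), (1, 2, 0, 3)))  # (50, 22, 1, 22)
print(try_transpose((1, 50, 22, 22), (1, 2, 0)))     # error text, only 3 axes given
```

Note that (1, 2, 0, 3) only avoids the error: it yields shape (50, 22, 1, 22), which is not an image layout on its own.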

What do the dimensions stand for in your example, [channel, batch_size, width, height]?

EDIT: In the STN example, they call make_grid on the images, which returns one image ([channels, height, width]) containing a grid of the “subimages”. Could you try using it, too?
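For intuition about what make_grid does with a batch: it pastes the images side by side (8 per row and 2 pixels of padding by default) into one big image. The hypothetical tile_channels helper below is a rough NumPy analogue for a single-channel [N, H, W] stack, not torchvision's implementation:

```python
import numpy as np

def tile_channels(volume, nrow=8, padding=2, pad_value=0.0):
    """Hypothetical helper: tile an [N, H, W] stack into one 2-D image.

    Rough NumPy analogue of the idea behind make_grid, not its exact code.
    """
    n, h, w = volume.shape
    ncols = min(nrow, n)
    nrows = -(-n // ncols)  # ceiling division
    grid = np.full(
        (nrows * (h + padding) + padding, ncols * (w + padding) + padding),
        pad_value,
        dtype=volume.dtype,
    )
    for idx in range(n):
        r, c = divmod(idx, ncols)          # row and column of this tile
        top = padding + r * (h + padding)
        left = padding + c * (w + padding)
        grid[top:top + h, left:left + w] = volume[idx]
    return grid

acts = np.random.randn(50, 22, 22)  # e.g. the 50 channels of one sample
grid = tile_channels(acts)
print(grid.shape)  # (170, 194)
```

For 50 maps of 22x22 this gives 7 rows of 8 tiles, i.e. a single 170x194 image.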

I see, I will try the change you suggested! Adding the fourth dimension makes much more sense. The four dimensions are batch, channel (50, because the input is the result of a previous convolution layer), height, and width.
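With that layout ([batch, channel, height, width]), the (1, 2, 0) transpose from the demo is meant for a single [channel, height, width] image. A small NumPy sketch, using a random array as a stand-in for the STN output, of two ways to get a channel-last layout:

```python
import numpy as np

# Stand-in for the STN output: [batch, channel, height, width]
stn_out = np.random.randn(1, 50, 22, 22)

# Move channels last for the whole batch: [batch, height, width, channel]
nhwc = stn_out.transpose(0, 2, 3, 1)
print(nhwc.shape)  # (1, 22, 22, 50)

# Or drop the batch axis first, then the 3-axis transpose works
chw = stn_out[0]              # [channel, height, width]
hwc = chw.transpose(1, 2, 0)  # [height, width, channel]
print(hwc.shape)  # (22, 22, 50)
```

Either way, 50 channels still can't be shown by matplotlib as one RGB image; each channel has to be displayed as its own grayscale map.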

I'll try what you suggested, and if it doesn't work, I'll try what the STN example shows as well. Thank you @ptrblck

Actually, just another question: since my input into the STN is the result of a convolution, the output of the STN will also contain 50 images of 22x22, since that's what the input to the STN was. So trying to visualize it as a single image won't work, right?

My question, then: is there a way to combine all 50 images into one image, or do I have to visualize all 50 of them separately?

You could visualize each “slice” of the activation volume coming from the convolution.
I created a small code snippet:

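import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

def show_activation_volume(grid):
    # make_grid returns [C, H, W]; matplotlib expects [H, W, C]
    grid_arr = grid.numpy().transpose(1, 2, 0)
    # min-max scale to [0, 1] for display
    grid_arr -= grid_arr.min()
    grid_arr /= grid_arr.max()
    plt.imshow(grid_arr)
    plt.show()

x = torch.randn(1, 3, 24, 24)
conv = nn.Conv2d(3, 50, 3)
output = conv(x).detach()        # [1, 50, 22, 22]
output = output.transpose(0, 1)  # [50, 1, 22, 22]: one 1-channel image per slice
grid = make_grid(output)         # tiles the 50 slices into one [3, H, W] image
show_activation_volume(grid)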
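One caveat with the min-max scaling in show_activation_volume: if every value in the grid is identical (e.g. all activations are zero), grid_arr.max() is 0 after subtracting the minimum, and the division produces NaNs. Also, grid.numpy() shares memory with the tensor, so the in-place -= and /= modify grid as well. A minimal sketch of a safer variant (normalize_01 and eps are illustrative names, not a library API):

```python
import numpy as np

def normalize_01(arr, eps=1e-8):
    """Hypothetical helper: min-max scale to [0, 1], safe for constant input."""
    arr = arr.astype(np.float64)      # astype copies, so the caller's array is untouched
    arr -= arr.min()
    return arr / max(arr.max(), eps)  # avoid 0/0 when every value is equal

print(normalize_01(np.array([[-2.0, 0.0], [2.0, 6.0]])))
print(normalize_01(np.full((2, 2), 3.0)))  # all zeros instead of NaNs
```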

Is this what you are looking for?

Thank you, that makes sense. I will try that!