How to change Spatial Transformer Network to support different image sizes


(Vlad) #1

Following the tutorial:
https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html

What would I need to do to adapt the STN network to work with a specific B = batch size, H = image height, W = image width, C = channels? I assume B is handled by .view(-1, …), but what about the remaining H, W, C values?

Thank you and my apologies if this is a bit of a beginner question! :slight_smile:


#2

Do you mean the view() op in stn() and forward()?
If you change the number of input channels, height, or width of your input, you would need to adapt the in_features of the first linear layer as well (and the first conv layer's in_channels for a different C), or alternatively use an adaptive pooling layer to force the desired output size before flattening.
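For example, here is a minimal sketch of the adaptive-pooling approach. The conv and linear sizes are meant to mirror the tutorial's localization net; C = 3 and the 64x48 input are just assumed example values:

```python
import torch
import torch.nn as nn

C = 3  # assumed: e.g. RGB input instead of the tutorial's 1-channel MNIST

localization = nn.Sequential(
    nn.Conv2d(C, 8, kernel_size=7),
    nn.MaxPool2d(2, stride=2),
    nn.ReLU(True),
    nn.Conv2d(8, 10, kernel_size=5),
    nn.MaxPool2d(2, stride=2),
    nn.ReLU(True),
    nn.AdaptiveAvgPool2d((3, 3)),  # forces a [B, 10, 3, 3] output for any H, W
)

fc_loc = nn.Sequential(
    nn.Linear(10 * 3 * 3, 32),  # in_features can stay fixed thanks to the adaptive pool
    nn.ReLU(True),
    nn.Linear(32, 3 * 2),
)

x = torch.randn(4, C, 64, 48)            # B=4, arbitrary H=64, W=48
xs = localization(x)                      # -> [4, 10, 3, 3]
theta = fc_loc(xs.view(-1, 10 * 3 * 3))   # -> [4, 6]
print(xs.shape, theta.shape)
```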

As you can see, self.localization returns an output of shape [batch_size, 10, 3, 3] for the tutorial's input size. Flattened, that is exactly the 10 * 3 * 3 in_features of the first linear layer in self.fc_loc.
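For reference, a quick shape walk-through for the tutorial's 1x28x28 MNIST input (assuming the tutorial's layer parameters: no padding, stride-1 convs):

```python
# input:               [B,  1, 28, 28]
# Conv2d(1, 8, 7):     [B,  8, 22, 22]   # 28 - 7 + 1 = 22
# MaxPool2d(2, 2):     [B,  8, 11, 11]
# Conv2d(8, 10, 5):    [B, 10,  7,  7]   # 11 - 5 + 1 = 7
# MaxPool2d(2, 2):     [B, 10,  3,  3]   # floor(7 / 2) = 3
# flattened:           [B, 10 * 3 * 3] = [B, 90]  -> fc_loc's first Linear
```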

The easiest way to find the new value would be to add print statements to your forward, pass an input with your new shape, and print the shape returned by the layers right before the linear layers.
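As a rough sketch, something like this (the print lines and the generic flatten are my additions; the rest follows the tutorial's stn()):

```python
import torch.nn.functional as F

def stn(self, x):
    xs = self.localization(x)
    print("localization output:", xs.shape)  # e.g. [B, 10, H', W'] for your input size
    xs = xs.view(xs.size(0), -1)             # flatten all dims except the batch dim
    print("flattened:", xs.shape)            # xs.shape[1] is the in_features fc_loc needs
    theta = self.fc_loc(xs)                  # errors until fc_loc's first Linear matches that size
    theta = theta.view(-1, 2, 3)
    grid = F.affine_grid(theta, x.size())
    return F.grid_sample(x, grid)
```

Once the printed size matches the in_features of fc_loc's first nn.Linear, you can remove the prints again.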