How to get output dimension within nn.Sequential()?

I am writing a flexible VAE that takes in inputs of varying sizes. This is the code that I have so far:

class Vanilla_3DVAE(nn.Module):
    """3D convolutional VAE (only the layer definitions are shown here).

    Shape comments assume a cubic input of side D; each conv/pool step below
    floors D to D // 2.
    """

    def __init__(self):
        super().__init__()
        # TODO: Parametrize the grid dimensions and channel values
        self.nn_1 = nn.Sequential(
            nn.Conv3d(8, 16, 4, 2, 1),   # k=4, s=2, p=1: D -> D // 2
            nn.ReLU(),
            nn.AvgPool3d(2),             # D -> D // 2
            nn.BatchNorm3d(16),
            nn.Conv3d(16, 32, 4, 2, 1),  # k=4, s=2, p=1: D -> D // 2
            nn.ReLU(),
            nn.AdaptiveAvgPool3d(1),     # any D x H x W -> 1x1x1
            nn.BatchNorm3d(32))
        self.nn_2 = nn.Sequential(
            nn.Linear(32, 16),  # expects the pooled features flattened to (N, 32)
            nn.ReLU())
        self.mu_fc = nn.Linear(16, 16)
        self.logvar_fc = nn.Linear(16, 16)

        # NOTE: the decoder's output size depends only on the size of its own
        # input, not on the encoder's input: a 1x1x1 input becomes 4 -> 10 -> 20
        # per side. It will not match variable-size inputs without resizing.
        self.decoder_nn = nn.Sequential(
            nn.ConvTranspose3d(16, 8, 4, 2, 0),
            nn.BatchNorm3d(8),
            nn.ConvTranspose3d(8, 16, 4, 2, 0),
            nn.BatchNorm3d(16),
            nn.ConvTranspose3d(16, 8, 4, 2, 1))

Since the inputs will have different dimensions, I’d like to keep track of the dimension size right before nn.AdaptiveAvgPool3d( 1 ), storing this in a variable. I don’t have access to the input dimension size in this .py file. What is the best way to do this?

As far as I know, there are two ways of doing this.

  1. If possible, I would split the AdaptiveAvgPool3d out of self.nn_1 into its own module,
    like so:
# (inside __init__) Encoder trunk without the adaptive pool, so its output shape can be read.
self.nn_1 = nn.Sequential(
            nn.Conv3d(8, 16, 4, 2, 1 ),        # k=4, s=2, p=1: D -> D // 2
            nn.ReLU(),
            nn.AvgPool3d(2),            # D -> D // 2
            nn.BatchNorm3d(16),
            nn.Conv3d(16, 32, 4, 2, 1 ),       # k=4, s=2, p=1: D -> D // 2
            nn.ReLU())
# Pooling head split off from nn_1.
self.adaptive_pool = nn.Sequential(
            nn.AdaptiveAvgPool3d( 1 ),  # any D x H x W -> 1x1x1
            nn.BatchNorm3d(32))

And then in your forward:

def forward(self, x):
        """Run the trunk, record its output shape, then pool.

        Returns the final output together with ``shape_before_pooling``, the
        ``.shape`` of the ``nn_1`` output (i.e. just before ``adaptive_pool``).
        """
        x = self.nn_1(x)
        shape_before_pooling = x.shape
        x = self.adaptive_pool(x)
        [...]  # placeholder: the remaining layers go here
        return x, shape_before_pooling
  2. If for some reason you cannot separate AdaptiveAvgPool3d from self.nn_1, you can register a forward hook on it instead (note that the hook only fills in `shapes` once a forward pass has run):
import torch

# Filled in by the hook: [input shape, output shape] of the hooked layer.
shapes = [None, None]


def get_shapes(module, inputs, output):
    """Forward hook that records the shapes around the hooked module.

    PyTorch calls it as ``hook(module, inputs, output)`` after the module's
    forward: ``inputs`` is the tuple of positional inputs and ``output`` is
    the returned tensor.
    """
    shapes[0] = inputs[0].shape  # shape before pooling
    shapes[1] = output.shape     # shape after pooling


net = Vanilla_3DVAE()
# nn_1[6] is the AdaptiveAvgPool3d layer of the original nn_1.
hook_handle = net.nn_1[6].register_forward_hook(get_shapes)

# Hooks only fire during a forward pass; printing right after registering
# would just show the placeholders. Run one dummy batch first. The spatial
# size is only an example; 8 channels are required by the first Conv3d.
was_training = net.training
net.eval()  # keep BatchNorm running statistics untouched by the dummy pass
with torch.no_grad():
    net.nn_1(torch.zeros(1, 8, 32, 32, 32))
net.train(was_training)
print(shapes)  # [torch.Size([1, 32, 4, 4, 4]), torch.Size([1, 32, 1, 1, 1])]

Thank you :slight_smile: Your help is pushing me in the right direction. Since I’m planning to use the variable that holds the pre-pooling shape in the decoder, would something like this work instead?

self.nn_1 = nn.Sequential( ... )
self.pre-pool_shape = x.shape
self.nn_adpative_pool = nn.Sequential(...)
self.nn_2 = nn.Sequential(....)

self.mu & logvar_fc layers

self.decoder_nn = nn.Sequential( ... <using variable pre-pool_shape> ...)

Or, if I use your first method, the dimensions are stored in shape_before_pooling, which is local to forward(). How can I use this variable in self.decoder_nn? Right now, my forward, encode and decode methods look like this:

 def encode(self, x):
        encoded = self.nn_1(x)
        encoded_flat = self.flatten(encoded)
        encoded_flat = self.nn_2(encoded_flat)
        mu = self.mu_fc(encoded_flat)
        logvar = self.logvar_fc(encoded_flat)
        return mu, logvar

    def decode(self, z):
        """Decode latent ``z`` with ``decoder_nn``.

        ``z`` first goes through ``restore_dim`` (not shown in this snippet),
        and that result is fed to the decoder unchanged.
        """
        return self.decoder_nn(self.restore_dim(z))

    def forward(self, x):
        """Full VAE pass: encode, reparametrize, decode.

        Returns ``(reconstruction, z, mu, logvar)`` where
        ``z = self.reparametrize(mu, logvar)``.
        """
        mean, log_var = self.encode(x)
        latent = self.reparametrize(mean, log_var)
        reconstruction = self.decode(latent)
        return reconstruction, latent, mean, log_var

If I make the variable within forward() as you have, create a new parameter in decode() as such: decode(self, z, prePoolShape), and inside that function, do return self.decoder_nn(z_4d, prePoolShape), would that work?

Sorry if this seems really beginner. I haven’t had too much experience with VAEs and neural networks in general, and again I really appreciate your help :slight_smile:

You can declare the variable in the constructor (`__init__`), but you need to assign its value in the forward function — or, in your case, in the encode function — because `x` only exists once data is passed through the model, not in `__init__`.
This should work:
Your constructor as I described it, plus the shape_before_pooling variable you wanted:

class Vanilla_3DVAE(nn.Module):
    """3D VAE whose encoder trunk is split from its adaptive pool.

    ``shape_before_pooling`` starts as ``None``; ``encode`` stores the
    ``.shape`` of the ``nn_1`` output there so the decoder side can use it.
    """

    def __init__(self):
        super().__init__()
        # Encoder trunk, without the adaptive pool so its output shape can be read.
        self.nn_1 = nn.Sequential(
            nn.Conv3d(8, 16, 4, 2, 1),   # k=4, s=2, p=1: D -> D // 2
            nn.ReLU(),
            nn.AvgPool3d(2),             # D -> D // 2
            nn.BatchNorm3d(16),
            nn.Conv3d(16, 32, 4, 2, 1),  # k=4, s=2, p=1: D -> D // 2
            nn.ReLU())
        # Pooling head: any D x H x W -> 1x1x1.
        self.adaptive_pool = nn.Sequential(
            nn.AdaptiveAvgPool3d(1),
            nn.BatchNorm3d(32))
        # The remaining layers are unchanged from the question.
        self.nn_2 = nn.Sequential(
            nn.Linear(32, 16),  # expects the pooled features flattened to (N, 32)
            nn.ReLU())
        self.mu_fc = nn.Linear(16, 16)
        self.logvar_fc = nn.Linear(16, 16)
        self.decoder_nn = nn.Sequential(
            nn.ConvTranspose3d(16, 8, 4, 2, 0),
            nn.BatchNorm3d(8),
            nn.ConvTranspose3d(8, 16, 4, 2, 0),
            nn.BatchNorm3d(16),
            nn.ConvTranspose3d(16, 8, 4, 2, 1))

        # This is new. Must use the same name that encode() assigns.
        self.shape_before_pooling = None

Then your forward, encode and decode methods look like this:

    def encode(self, x):
        """Encode ``x`` into the posterior parameters ``(mu, logvar)``.

        Side effect: the ``.shape`` of the ``nn_1`` output, i.e. the tensor
        right before ``adaptive_pool``, is stored in
        ``self.shape_before_pooling`` for later use (e.g. by the decoder).
        """
        features = self.nn_1(x)
        # Record the size before the adaptive pool collapses it to 1x1x1.
        self.shape_before_pooling = features.shape
        hidden = self.nn_2(self.flatten(self.adaptive_pool(features)))
        return self.mu_fc(hidden), self.logvar_fc(hidden)

    def decode(self, z):
        """Decode latent ``z`` with ``decoder_nn``.

        ``z`` first goes through ``restore_dim`` (not shown here). The
        pre-pool shape recorded by ``encode`` is available as
        ``self.shape_before_pooling`` if the decoder needs it; this version
        does not use it.
        """
        return self.decoder_nn(self.restore_dim(z))

    def forward(self, x):
        """Full VAE pass: encode, reparametrize, decode.

        Returns ``(reconstruction, z, mu, logvar)`` where
        ``z = self.reparametrize(mu, logvar)``. Because ``encode`` runs first,
        ``self.shape_before_pooling`` is already updated when ``decode`` runs.
        """
        mean, log_var = self.encode(x)
        latent = self.reparametrize(mean, log_var)
        reconstruction = self.decode(latent)
        return reconstruction, latent, mean, log_var

I am fairly certain this should work.

1 Like