Transfer learning usage with different input size

VGG16 and ResNet require input images of size 224x224x3. I know my question may be naive, but is there any way to use these pretrained networks on datasets with different input sizes, for example black-and-white images of size 224x224x1, or images of a different size that I don’t want to resize?

Thanks!


In your first use case (a different number of input channels) you could add a conv layer before the pretrained model that outputs 3 channels (out_channels=3).

For different input sizes you could have a look at the source code of vgg16. There you could perform some model surgery and replace the max pooling before the classifier with an adaptive pooling layer, so the classifier always gets its expected input shape (512*7*7).

Note that the performance of your pre-trained model might differ for different input sizes.
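A minimal sketch of that surgery, assuming a recent torchvision where vgg16 exposes an avgpool attribute (older versions may need the pooling layer inserted into model.features instead):

import torch
import torch.nn as nn
import torchvision.models as models

model = models.vgg16(pretrained=True)
# Adaptive pooling always returns a 7x7 map, so the classifier
# still receives its expected 512*7*7 features.
model.avgpool = nn.AdaptiveAvgPool2d((7, 7))

x = torch.randn(1, 3, 320, 320)  # an input size other than 224x224
output = model(x)
print(output.shape)  # torch.Size([1, 1000])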


Thanks a lot ptrblck! The way to implement the first part of your answer would be something like this -

    model = models.vgg16(pretrained=True)
    first_conv_layer = list(nn.Conv2d(1, 3, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True))
    first_conv_layer.extend(list(model.features))
    model.features = nn.Sequential(*first_conv_layer)

?

Looks good with some minor tweaks.
I think you’ll get an error calling list(nn.Conv2d(...)), since a single module is not iterable; wrap it in a plain list instead.
Also, since we didn’t change the input size, we should make sure the new first conv layer returns the same spatial dimensions. Using kernel_size=3 with padding=1 keeps the size unchanged.
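(Using the conv output-size formula: out = floor((in + 2*padding - kernel_size) / stride) + 1 = floor((224 + 2 - 3) / 1) + 1 = 224.)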
Here is the code:

import torch
import torch.nn as nn
import torchvision.models as models

x = torch.randn(1, 1, 224, 224)  # single-channel input
model = models.vgg16(pretrained=False)  # pretrained=False just for debug reasons
first_conv_layer = [nn.Conv2d(1, 3, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True)]
first_conv_layer.extend(list(model.features))
model.features = nn.Sequential(*first_conv_layer)
output = model(x)

Thanks a lot ptrblck!

Assuming we use a pretrained VGG, let’s say we add a conv layer at the beginning (e.g. if we want the input to be a tensor with 2 channels instead of 3), and we also replace the last fc layers to match our output size. When training, we usually freeze all params in model.features with requires_grad=False and only train the model.classifier part. Now that we have also added a conv layer at the beginning, do we need to train this part of model.features as well? Example code:

def build_model(self):
    if self.model_type in ['vgg-16', 'vgg-19']:
        assert self.image_size == 224, "ERROR: Wrong image size."
        model = torchvision.models.vgg16(pretrained=True) if self.model_type == 'vgg-16' else torchvision.models.vgg19(pretrained=True)
        if self.input_ch != 3:
            first_conv_layer = [nn.Conv2d(self.input_ch, 3, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True)]
            first_conv_layer.extend(list(model.features))
            model.features = nn.Sequential(*first_conv_layer)

        model.classifier[-1] = nn.Linear(4096, 1000)
        model.classifier.add_module('7', nn.ReLU())
        model.classifier.add_module('8', nn.Dropout(p=0.5, inplace=False))
        model.classifier.add_module('9', nn.Linear(1000, self.output_ch))
        model.classifier.add_module('10', nn.LogSoftmax(dim=1))

        for param in model.features.parameters():  # disable grad for trained layers
            param.requires_grad = False

        return model.to(self.device)

Yes, you should probably train the first layer; otherwise it’ll stay randomly initialized.
You could of course freeze all the other layers, which are already pretrained. :wink:

I guess the following code should do it:

assert self.image_size == 224, "ERROR: Wrong image size."
model = torchvision.models.vgg16(pretrained=True) if self.model_type == 'vgg-16' else torchvision.models.vgg19(pretrained=True)
if self.input_ch != 3:
    first_conv_layer = [nn.Conv2d(self.input_ch, 3, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True)]
    first_conv_layer.extend(list(model.features))
    model.features = nn.Sequential(*first_conv_layer)

for param in model.features[1:].parameters():  # freeze the pretrained layers, keep the new first conv trainable
    param.requires_grad = False
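As a follow-up sketch (the optimizer setup is an assumption, not from your code), you could then pass only the trainable parameters to the optimizer:

import torch.optim as optim

# Parameters with requires_grad=False wouldn't be updated anyway,
# but filtering keeps the optimizer state small and the intent explicit.
optimizer = optim.SGD(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-3, momentum=0.9
)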

Hi again,

I’m trying to train a pretrained vgg-16 on the Fashion-MNIST dataset. I added an extra conv layer at the beginning and modified the last layers to end up with 10 output nodes. I also applied nn.LogSoftmax(dim=1) at the end, because I’m using nn.NLLLoss(). I disabled gradients for all layers in model.features except the first one, so only the first conv layer and the model.classifier layers are trained. I ran my code on two PCs, one with 5 GB of GPU memory and one with 6 GB, and both failed. Then I tried Google Colab on a GPU, and it still takes too long to run even one epoch. This is my code:

if self.model_type in ['vgg-16', 'vgg-19']:
    assert self.image_size == 224, "ERROR: Wrong image size."
    model = torchvision.models.vgg16(pretrained=True) if self.model_type == 'vgg-16' else torchvision.models.vgg19(pretrained=True)
    if self.input_ch != 3:
        first_conv_layer = [nn.Conv2d(self.input_ch, 3, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True)]
        first_conv_layer.extend(list(model.features))
        model.features = nn.Sequential(*first_conv_layer)

    model.classifier[-1] = nn.Linear(4096, 1000)
    model.classifier.add_module('7', nn.ReLU())
    model.classifier.add_module('8', nn.Dropout(p=0.5, inplace=False))
    model.classifier.add_module('9', nn.Linear(1000, self.output_ch))
    model.classifier.add_module('10', nn.LogSoftmax(dim=1))

    for param in model.features[1:].parameters():  # disable grad for the pretrained layers
        param.requires_grad = False

Is this normal, or am I doing something wrong?

UPDATE: For batch_size=64 it took more than 10 minutes for 1 epoch. Is this normal? I’m going to try a larger batch_size as well to see what happens.

UPDATE 2: For a batch size of 256 I got this error:

RuntimeError: CUDA out of memory. Tried to allocate 784.00 MiB (GPU 0; 15.90 GiB total capacity; 14.87 GiB already allocated; 323.88 MiB free; 14.88 GiB reserved in total by PyTorch)

Wow, 14.88 GB already reserved? Isn’t that too much?

Answered here.

I am building a classifier for MRIs using a pretrained AlexNet, so my batch size has become the number of MRI slices: for example, one MRI has 30 slices, so the input shape becomes [30, 3, 256, 256]. But I want to parallelize training by passing batches of MRIs, let’s say batches of 8 MRIs, so the input shape will be [8, 30, 3, 256, 256]. How can I alter AlexNet, or any pretrained model from torchvision, to accept this input size?

You could reshape the input such that the batch and slice dimensions are merged into dim0, which would thus increase the batch size, via x = x.view(-1, 3, 256, 256).
This would treat each slice as its own input, in the same way as your previous approach.

Alternatively, you might want to treat the slice dimension as the depth dimension.
In that case, you would need to change the model architecture, since you would need 3D layers such as nn.Conv3d.
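A minimal sketch of the first approach:

import torch

x = torch.randn(8, 30, 3, 256, 256)  # [MRIs, slices, channels, height, width]
x = x.view(-1, 3, 256, 256)          # merge MRIs and slices into the batch dim
print(x.shape)  # torch.Size([240, 3, 256, 256])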

Now I am getting this error: RuntimeError: CUDA out of memory. Tried to allocate 62.00 MiB (GPU 0; 8.00 GiB total capacity; 5.90 GiB already allocated; 55.97 MiB free; 5.96 GiB reserved in total by PyTorch)

This would mean that a batch size of 240 images is too large for your current device.
You could either decrease the batch size, the number of slices, or try to use torch.utils.checkpoint to trade compute for memory.
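A rough sketch of the checkpointing idea (the model and segment count here are just placeholders):

import torch
import torchvision
from torch.utils.checkpoint import checkpoint_sequential

model = torchvision.models.alexnet(pretrained=True)
x = torch.randn(16, 3, 256, 256, requires_grad=True)

# Run the conv stack in 4 segments; activations inside each segment are
# recomputed during the backward pass instead of stored, saving memory.
features = checkpoint_sequential(model.features, 4, x)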

And what if I have images bigger than 224, say 562x562x3? What would be the best way to modify the network?


I didn’t get the 5th line:

model.features = nn.Sequential(*first_conv_layer)

What is it doing?

This line unwraps the list and passes each element sequentially to the nn.Sequential container.
If you passed the list directly, you would get a TypeError:

import torch.nn as nn

my_list = []
my_list.append(nn.Linear(1, 1))
my_list.append(nn.Linear(1, 1))
my_list.append(nn.Linear(1, 1))

module = nn.Sequential(my_list)
> TypeError: list is not a Module subclass

module = nn.Sequential(*my_list)  # works

I know we can add a layer before the network and an fc layer at the end to match the number of classes. But when we train the model with the pretrained part frozen, where does the layer added at the input side get its gradients from? Even if it gets them from the last trained layer, do the dimensions match, and is it right to receive gradients from that layer? Since backpropagation is a chain of partial derivatives, it should be receiving gradients from the layer right next to it.
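A minimal toy check of this behavior: gradients still flow back through a frozen layer to the layer before it; only the frozen layer’s own parameters stop receiving them.

import torch
import torch.nn as nn

first = nn.Linear(4, 4)    # stands in for the newly added, trainable layer
frozen = nn.Linear(4, 4)   # stands in for the frozen pretrained part
for p in frozen.parameters():
    p.requires_grad = False

loss = frozen(first(torch.randn(2, 4))).sum()
loss.backward()
print(first.weight.grad is not None)  # True: the first layer receives gradients
print(frozen.weight.grad)             # None: the frozen parameters do not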

@ptrblck My virgin post will be dedicated here :slight_smile:

First off, I really appreciate your swift and detailed responses; most of the PyTorch forum threads I visited have your name in them. It is really not easy replying to so many issues and questions.

Now, although not a complete beginner, I am relatively new to PyTorch, and I have decided that it is my native language from now on. The question is relatively straightforward: for pretrained models like VGG16, or even more advanced models from pytorch-image-models, whether I need to input the shape that the pretrained model was trained on remains a puzzle to me. From my basic understanding, convolutional layers are invariant to image size, but dense layers aren’t. Yet I was resizing images in PyTorch’s pipeline without any regard for why it worked like a charm.

In short, I use either albumentations or torchvision to resize my images to, say, 512x512 (as this image size seems to be a favourite among Kagglers). Then I have a very simple class wrapper around nn.Module, as follows:

class CustomPretrainedModel(nn.Module):
    def __init__(self, pretrained: bool = True):
        super().__init__()
        self.model = create_pretrained_model_here(......)
        n_features = self.model.classifier.in_features
        self.model.classifier = nn.Linear(n_features, num_classes)

    def forward(self, input_neurons):
        output_predictions = self.model(input_neurons)
        return output_predictions

Then I run the pipeline without any hiccups. So my main question is: is the flexibility to feed any image size into PyTorch’s native pretrained models (and other well-known PyTorch repos) mainly due to nn.AdaptiveAvgPool2d((x, x))? I am not too familiar with it, so it would be great to know how pretrained models can handle all kinds of image sizes :slight_smile:


Thanks for the kind words, and yes, you are correct. The input shapes are flexible due to the adaptive pooling layers, which make sure the output activation shape matches the feature dimension expected by the linear layers.
Note that larger inputs work without a problem, but the model is still limited to a minimum input size, since the intermediate activations could otherwise become empty.
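A quick sketch of both points, assuming a recent torchvision vgg16 (which already ships with an AdaptiveAvgPool2d):

import torch
import torchvision.models as models

model = models.vgg16(pretrained=False)

for size in (224, 512):
    out = model(torch.randn(1, 3, size, size))
    print(size, out.shape)  # both print torch.Size([1, 1000])

# Too-small inputs fail: a 16x16 image shrinks to nothing in the five
# max-pool stages, so the forward pass raises a RuntimeError.
# model(torch.randn(1, 3, 16, 16))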
