Transfer learning with different inputs

Hi,

I have used the transfer learning example provided on the website and it works pretty well. So I plan to apply a pre-trained model, such as ResNet18, to two different modalities and then fuse the two models at the FC layer to continue training. Any thoughts on how this could be implemented here? The network parameters could either be shared or not shared between the two branches. Thanks.


What do you mean by “different modalities”?
As far as I understand, you would like to take two pre-trained models and concatenate their activations at some point?

To clarify: for instance, one input is the raw RGB image, and the other is the depth image.

Yes, concatenation at an early stage or a late stage should be fine. Thanks.

Ah ok, I see.
So I suppose the two inputs have different numbers of channels, i.e. the image has 3 channels while the depth map has 1 channel?
I created a small code snippet:

import torch
import torch.nn as nn
from torchvision import models


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

        # ResNet18 backbone (without its final FC layer) for the RGB image
        image_modules = list(models.resnet18().children())[:-1]
        self.modelA = nn.Sequential(*image_modules)

        # second backbone for the depth image; a 1->3 conv maps the single
        # depth channel to the 3 channels the ResNet stem expects
        depth_modules = list(models.resnet18().children())[:-1]
        self.modelB = nn.Sequential(nn.Conv2d(1, 3, 3, 1, 1),
                                    *depth_modules)

        # each backbone yields 512 features, so the concatenated vector has 1024
        self.fc = nn.Linear(1024, 1)

    def forward(self, image, depth):
        a = self.modelA(image)
        b = self.modelB(depth)
        x = torch.cat((a.view(a.size(0), -1), b.view(b.size(0), -1)), dim=1)
        x = self.fc(x)
        x = torch.sigmoid(x)
        return x


x_image = torch.randn(1, 3, 224, 224)
x_depth = torch.randn(1, 1, 224, 224)

model = MyModel()
output = model(x_image, x_depth)

Does it suit your needs?


Alternatively, you could add the depth information as a fourth channel and edit the first layer of resnet18 so that it takes 4 input channels instead of three.

To steal from @ptrblck’s nice example:


import torch
import torch.nn as nn
from torchvision import models

x_image = torch.randn(1, 3, 224, 224)
x_depth = torch.randn(1, 1, 224, 224)

# concatenate along the channel dimension -> 4-channel RGBD input
x_rgbd = torch.cat((x_image, x_depth), dim=1)

model = models.resnet18()
# replace the stock 3-channel conv1 with a 4-channel version
model.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)

output = model(x_rgbd)

I’m not sure which would work better for your purposes, but this saves a lot of parameters compared to the two-branch (siamese-style) approach above.
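For a rough sense of the difference, here is a quick parameter-count sketch (the numbers are approximate, and the two stripped backbones below only stand in for the two-branch model above):

import torch.nn as nn
from torchvision import models

def count_params(m):
    return sum(p.numel() for p in m.parameters())

# single backbone with a 4-channel first conv (RGBD input)
rgbd_model = models.resnet18()
rgbd_model.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)

# two separate backbones, roughly what the two-branch model above builds
branch_a = nn.Sequential(*list(models.resnet18().children())[:-1])
branch_b = nn.Sequential(*list(models.resnet18().children())[:-1])

print(count_params(rgbd_model))                         # ~11.7M parameters
print(count_params(branch_a) + count_params(branch_b))  # ~22.4M parameters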


What changes need to be made to the data loader here?

Basically none. Your Dataset should provide the samples containing 4 channels.

But a custom class has to be written for the data loader, if I am right.

You would write a custom Dataset and just pass it to the DataLoader.
Here is a small example:


from PIL import Image
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    def __init__(self, image_paths, targets, transform=None):
        self.image_paths = image_paths
        self.targets = targets
        self.transform = transform

    def __getitem__(self, index):
        # load a single sample and its target
        x = Image.open(self.image_paths[index])
        y = self.targets[index]
        if self.transform:
            x = self.transform(x)

        return x, y

    def __len__(self):
        return len(self.image_paths)


dataset = MyDataset(image_paths, targets)
loader = DataLoader(
    dataset,
    batch_size=10,
    shuffle=True
)
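For the RGBD case specifically, __getitem__ could load the RGB and depth images separately and concatenate them into a single 4-channel tensor. A minimal sketch, assuming you keep matching image_paths and depth_paths lists (both names are just placeholders):

import torch
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms.functional as TF

class RGBDDataset(Dataset):
    def __init__(self, image_paths, depth_paths, targets):
        self.image_paths = image_paths
        self.depth_paths = depth_paths
        self.targets = targets

    def __getitem__(self, index):
        rgb = TF.to_tensor(Image.open(self.image_paths[index]).convert('RGB'))  # [3, H, W]
        depth = TF.to_tensor(Image.open(self.depth_paths[index]).convert('L'))  # [1, H, W]
        x = torch.cat((rgb, depth), dim=0)                                      # [4, H, W]
        return x, self.targets[index]

    def __len__(self):
        return len(self.image_paths)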

Hey,

The input channel size in my case is 19, not 3. I would like to use the pre-trained resnet18. I found a way to handle this in Keras, but not in PyTorch. Do you have any ideas?

I don’t know how you handle it in Keras, but you could:

  • replace the first conv layer with a new one accepting 19 channels
  • manipulate the first conv layer by increasing the in_channel dimension of the weight (e.g. by repeating etc.)
  • add an entirely new conv layer (let’s call it layer0) which would accept 19 channels and return an activation with 3 channels
  • manipulate the input data and reduce the channel dimension to 3
  • split the input in some way and pass only 3 channels to the model

I don’t know which approach would work best or which one you are already using, but I would probably just replace the first conv layer and make sure the new layer accepts 19 channels.
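A minimal sketch of the first two options (replacing conv1 outright, and optionally initializing the new weight by repeating the pretrained 3-channel weight; the repetition scheme is just one possible choice):

import torch
import torch.nn as nn
from torchvision import models

# use pretrained=True instead of weights=... on older torchvision versions
model = models.resnet18(weights="IMAGENET1K_V1")

# keep a copy of the pretrained 3-channel weight before replacing the layer
pretrained_w = model.conv1.weight.detach().clone()   # [64, 3, 7, 7]

# replace the first conv layer with one accepting 19 channels
model.conv1 = nn.Conv2d(19, 64, kernel_size=7, stride=2, padding=3, bias=False)

# optional: initialize it by repeating the pretrained weight along in_channels
with torch.no_grad():
    model.conv1.weight.copy_(pretrained_w.repeat(1, 7, 1, 1)[:, :19])

x = torch.randn(1, 19, 224, 224)
out = model(x)
print(out.shape)   # torch.Size([1, 1000])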

Hi!
I’m working on a similar model, except that instead of depth images I need to add numeric input. What does the star mean in this line:

self.modelA = nn.Sequential(*image_modules)

The asterisk (*) unpacks the list and passes the content as positional arguments to nn.Sequential.
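In other words, these two lines build the same model (a tiny made-up illustration, not from your code):

import torch.nn as nn

layers = [nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3)]
model_a = nn.Sequential(*layers)                          # list unpacked into positional args
model_b = nn.Sequential(layers[0], layers[1], layers[2])  # equivalent, written out by hand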

Thanks!
I’ve made some changes to the model to match my input.
The image input is torch.Size([1, 4, 224, 224]), the numeric input (48 landmarks with 2 coordinates each) is torch.Size([1, 48, 2]), and the label is torch.Size([1]).

class MixedNetwork(nn.Module):
    def __init__(self):
        super(MixedNetwork, self).__init__()
        
        image_modules = list(models.resnet18().children())[:-1]
        self.image_features = nn.Sequential(*image_modules)

        self.landmark_features = nn.Sequential(
            nn.Linear(in_features=2,out_features=16,bias=False), 
            nn.ReLU(inplace=True), 
            nn.Dropout(p=0.25),
            nn.Linear(in_features=16,out_features=2,bias=False), 
            nn.ReLU(inplace=True), 
            nn.Dropout(p=0.25))
        
        self.combined_features = nn.Sequential(
            nn.Linear(32,16),
            nn.ReLU(),
            nn.Linear(16,2)
        )
        
    def forward(self, image, landmarks):
        a = self.image_features(image)
        b = self.landmark_features(landmarks)
        x = torch.cat((a.view(a.size(0), -1), b.view(b.size(0), -1)), dim=1)
        x = self.combined_features(x)
        x = F.sigmoid(x)
        return x

I keep getting the error:
RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [1, 1, 4, 224, 224]

I don’t understand where the extra 1 in the size comes from, since the input tensor is [1, 4, 224, 224]. Secondly, the picture is in color, so 4 channels are okay for resnet18, right?

Based on the error message your input is 5-dimensional, so print the shape inside the forward and check it again.
A 4-dimensional input works fine:

import torch
import torch.nn as nn
from torchvision import models

image_modules = list(models.resnet18().children())[:-1]
image_features = nn.Sequential(*image_modules)

x = torch.randn(1, 3, 224, 224)
out = image_features(x)
print(out.shape)
# torch.Size([1, 512, 1, 1])

No, as 3 input channels are expected. You could either remove the alpha channel or replace the first conv layer with a new one accepting 4 channels.
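Both options in a minimal sketch (the tensor is random data standing in for your RGBA image):

import torch
import torch.nn as nn
from torchvision import models

x = torch.randn(1, 4, 224, 224)   # RGBA-style input
model = models.resnet18()

# option 1: drop the alpha channel and keep the stock 3-channel conv1
out = model(x[:, :3])

# option 2: replace conv1 with a 4-channel version
model.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)
out = model(x)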

Thanks! First, I got rid of the alpha channel, so the image tensor now has 3 channels.
I also tried the dimensions check:

for i in range(len(train_dataloader)):
    sample = train_dataset[i]
    print(i, sample['image'].size(), sample['landmarks'].size(), sample['labels'].size())

    if i == 0:
      break

image_modules = list(models.resnet18().children())[:-1]
image_features = nn.Sequential(*image_modules)
image_features.double()

out = image_features(sample["image"])
print(out.shape)

out.shape was torch.Size([1, 512, 1, 1]), as you wrote.
However, when I tried to apply the whole model with the same dataloader, it didn’t work.

RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [1, 1, 3, 224, 224]

This is the same dataloader, so what else could cause this dimensionality issue?

I don’t know as I don’t fully understand the difference between the working use case and the failing one.
Could you post a minimal, executable code snippet showing the failure please?

Sorry for not making myself clear.
What I meant is the following:
I have a training dataset consisting of images, landmarks on the image, and labels.

train_dataset = FaceLandmarksDataset(train_table, images_path, 
                                    transform=transforms.Compose([scale, ToTensor()]))
sample = train_dataset[3]
print(sample['image'].size(), sample['landmarks'].size(), sample['labels'].size())

Print results: torch.Size([1, 3, 224, 224]) torch.Size([1, 48, 2]) torch.Size([1, 1])

Then I make a dataloader using this dataset and print out the tensor sizes:

train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=False, num_workers=0)

for i_batch, sample_batched in enumerate(train_dataloader):
    print(i_batch, sample_batched['image'].size(), sample_batched['landmarks'].size(), sample_batched['labels'].size())

    if i_batch == 1:
      break

Print result: torch.Size([1, 1, 3, 224, 224]) torch.Size([1, 1, 48, 2]) torch.Size([1, 1, 1])
So somewhere between the dataset and the DataLoader an extra dimension is added to the tensors. This is what I don’t understand.

These shapes are unexpected, as your Dataset already returns samples with a batch dimension.
In the standard use case the batch dimension is only added by the DataLoader, e.g. as seen in this example:

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

dataset = datasets.CIFAR10(root='./data', download=False, transform=transforms.ToTensor())

x, y = dataset[0]
print(x.shape)
# torch.Size([3, 32, 32])
print(y)
# 6

loader = DataLoader(dataset, batch_size=1)
x, y = next(iter(loader))
print(x.shape)
# torch.Size([1, 3, 32, 32])
print(y)
# tensor([6])

Check how the __getitem__ method is defined in your FaceLandmarksDataset and make sure you are not adding a batch dimension.
Note that calling x = x.squeeze(1) on each batch returned by the DataLoader would also remove the additional dimension, but it would be a better idea to properly fix where the unneeded dimension is added.
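To make the failure mode concrete, here is a self-contained toy example (the datasets are made up, not your FaceLandmarksDataset) showing how an unsqueeze(0) inside __getitem__ produces exactly this extra dimension:

import torch
from torch.utils.data import Dataset, DataLoader

class WrongDataset(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, index):
        # the unsqueeze(0) adds a per-sample batch dim -> [1, 3, 224, 224]
        return torch.randn(3, 224, 224).unsqueeze(0)

class FixedDataset(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, index):
        # return the plain per-sample tensor -> [3, 224, 224]
        return torch.randn(3, 224, 224)

print(next(iter(DataLoader(WrongDataset(), batch_size=1))).shape)  # torch.Size([1, 1, 3, 224, 224])
print(next(iter(DataLoader(FixedDataset(), batch_size=1))).shape)  # torch.Size([1, 3, 224, 224])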

Thanks a lot! I understand now: I fixed __getitem__ and now it’s working properly.