Fusing Additional Input/Features for Transfer Learning

Hey! I am new to PyTorch. I am trying to do transfer learning, and during training I want to feed some additional input for each data point. So let’s say image_1 has a feature vector like (1.01, 0.5, 0.06) and image_2 has one like (0.9, 0.001, 0.023), and so on; I would like to bundle each image with its features.

The notebook I am following: https://github.com/WillKoehrsen/pytorch_challenge/blob/master/Transfer%20Learning%20in%20PyTorch.ipynb

The notebook shows how we can replace the fc layer at the end and insert whatever we want. For my problem, would I then need to change the very first Conv2d layer (for resnet18) and somehow pass in the features? For example, by adding an fc layer there and fusing the data? Or are there other ways to do this, like fusing at the tensor level?

In a programming sense, how should I go about doing this?

Thanks

What kind of additional data do you have?
Is it also an image or do you have any other features (e.g. 1-dimensional) for each image?

Thank you so much for getting back!

Additional data is numeric. Features would be 2 to 4 dimensional, definitely not more than that.

After going through the vision forum, I found some other posts that were very similar to what I am trying to do; I am referring to this post specifically. But that raised some further questions. I have a total of roughly 250 images across two classes, named PD and Non_PD.

So let’s say I have a CSV file with the header: filename, feature_1, feature_2, feature_3, feature_4. The files in filename are all located in the PD and Non_PD folders, but the CSV file gives no indication of which file belongs to which class (I could add a column if it would make my life easier).

That being said, for my problem, what would additional_data_dimension (as referred to in the post I mentioned) be? 4? And how would I read in the data here, through a DataLoader? Right now my code is getting all the images from the train directory. What would the code look like for that?

Based on the file description, it seems your additional features are 1-dimensional with a size of 4.
What is stored in the filename column? Are these your image file names (or paths)?
If so, I would write a custom Dataset and index the DataFrame inside __getitem__ to get the file path to the image, the additional features, and the current target (I would recommend storing the target in a separate column).
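A minimal sketch, assuming you add a target column to the CSV (the class name MyDataset, the 'target' column, and the 0/1 mapping are assumptions):

import os
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        self.df = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image = Image.open(os.path.join(self.root_dir, row['filename'])).convert('RGB')
        # feature_1 ... feature_4 -> float tensor of shape [4]
        features = torch.tensor(row[['feature_1', 'feature_2', 'feature_3', 'feature_4']].values.astype('float32'))
        target = int(row['target'])  # assumes an added 'target' column holding 0 or 1
        if self.transform:
            image = self.transform(image)
        return image, features, target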

Let me know if that would work for you.

Yup that’s how it should work I believe!

The custom Dataset class looks like this:

import os
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset

class PD_Dataset(Dataset):

	def __init__(self, csv_file, root_dir, transform=None):
		self.pd_features = pd.read_csv(csv_file)
		self.root_dir = root_dir
		self.transform = transform
		self.classes = ("PD", "Non-PD")

	def __len__(self):
		return len(self.pd_features)

	def __getitem__(self, idx):
		if torch.is_tensor(idx):
			idx = idx.tolist()

		# column 0 holds the file name; join it with the root directory
		img_name = os.path.join(self.root_dir, self.pd_features.iloc[idx, 0])
		image = Image.open(img_name).convert("RGB")
		# columns 1-4 hold the additional features, the last column the target
		features = torch.tensor(self.pd_features.iloc[idx, 1:5].values.astype("float32"))
		labels = self.pd_features.iloc[idx, 5]

		if self.transform:
			image = self.transform(image)
		return image, features, labels
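And I am planning to wrap it in a DataLoader roughly like this (the paths and transform here are placeholders):

from torch.utils.data import DataLoader
from torchvision import transforms

train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()])

train_dataset = PD_Dataset(csv_file="train_features.csv",
                           root_dir="train/",
                           transform=train_transforms)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)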

Now one thing I have to figure out is reshaping the features variable so that its shape matches the image variable’s shape. The features variable looks like [(x1, x2, x3)], and I was told that I need to pad a tensor with the same values so that its shape becomes the shape of image.

class MyModel(nn.Module):
	def __init__(self):
		super(MyModel, self).__init__()
		self.cnn = models.resnet18(pretrained=True)
		self.cnn.fc = nn.Sequential(
			nn.Linear(self.cnn.fc.in_features, 256),
			nn.ReLU(),
			nn.Dropout(0.2),
			nn.Linear(256, 2),
			nn.LogSoftmax(dim=1))

	def forward(self, image, data):
		x1 = self.cnn(image)
		x2 = data

		x = torch.cat((x1, x2), dim=1)
		x = self.cnn.fc(x)
		return x

model = MyModel()

Now for the training part:

for epoch in range(n_epochs):

		train_loss = 0.0
		valid_loss = 0.0

		train_acc = 0
		valid_acc = 0

		model.train()
		start = timer()

		for ii, (data, target) in enumerate(train_loader):
			if train_on_gpu:
				data, target = data.cuda(), target.cuda()

			optimizer.zero_grad()

			output = model(data)

			# target is already a tensor, so just cast it to long instead of re-wrapping it
			loss = criterion(output, target.long())
			loss.backward()

			optimizer.step()

			train_loss += loss.item() * data.size(0)

			_, pred = torch.max(output, dim=1)
			correct_tensor = pred.eq(target.data.view_as(pred))
			accuracy = torch.mean(correct_tensor.type(torch.FloatTensor))
			train_acc += accuracy.item() * data.size(0)

Since my Dataset will now return items in the format image, features, labels, do I have to enumerate the train_loader in a different way and handle it manually? The reason I ask is that I am not seeing where the forward function’s inputs are taken care of.

Would appreciate some help on this part!

Great to see that your data loading pipeline seems to work.

I’m a bit skeptical about your forward implementation.
Currently you are reusing self.cnn.fc: the image batch passes through it once inside x1 = self.cnn(image), and then the concatenated tensor passes through it again in x = self.cnn.fc(x).
Is this intentional (it might make sense), or would you rather concatenate the penultimate resnet18 output of the image batch with your additional data?

In the latter case, replace the last linear layer with self.cnn.fc = nn.Identity() and define the new classification block as self.fc = nn.Sequential(...).
With this approach you won’t need to pad your additional feature tensor to the image size; you can just concatenate it in the feature dimension (dim=1), as is already done in your code.

Take care with the scaling of the additional features, as we’ve seen some issues with this approach in the past when the value ranges are quite different.
E.g. if x1 has values in [0, 1], while x2 has values in [0, 100], the last block might just “focus” on the larger values and treat the smaller ones as noise.
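A simple fix is to standardize the additional features before concatenating them; a minimal sketch (the statistics would be computed once on your training set, the numbers here are placeholders):

import torch

# per-feature mean/std from the training set (placeholder values)
feat_mean = torch.tensor([0.5, 50.0, 5.0])
feat_std = torch.tensor([0.3, 30.0, 3.0])

def normalize_features(features):
    # zero mean / unit variance, so all inputs to the final
    # classifier end up in a comparable range
    return (features - feat_mean) / feat_std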

Since your DataLoader will now yield three tensors per batch, use:

for ii, (data1, data2, target) in enumerate(train_loader):

and feed both data tensors to the model via model(data1, data2).
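A sketch of the adjusted loop body (the device handling follows your current code):

for ii, (data1, data2, target) in enumerate(train_loader):
    if train_on_gpu:
        data1, data2, target = data1.cuda(), data2.cuda(), target.cuda()

    optimizer.zero_grad()
    output = model(data1, data2)  # image batch and additional features
    loss = criterion(output, target.long())
    loss.backward()
    optimizer.step()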

Thanks for the quick response!
x1 has values in [600, 800], x2 in [100, 200] and x3 in [10, 30]

So the model would become something like this if I didn’t want to bother with padding:

class MyModel(nn.Module):
	def __init__(self):
		super(MyModel, self).__init__()
		self.cnn = models.resnet18(pretrained=True)
		self.cnn.fc = nn.Identity()
		self.fc = nn.Sequential(
			nn.Linear(self.cnn.fc.in_features, 256),
			nn.ReLU(),
			nn.Dropout(0.2),
			nn.Linear(256, 2),
			nn.LogSoftmax(dim=1))

	def forward(self, image, data):
		x1 = self.cnn.fc(image)
		x2 = data

		x = torch.cat((x1, x2), dim=1)
		x = self.fc(x)
		return x

model = MyModel()

Did I understand you correctly?

Some minor mistakes:

  • Change x1 = self.cnn.fc(image) to x1 = self.cnn(image).
  • self.cnn.fc.in_features will be undefined, as self.cnn.fc is now an nn.Identity module, which has no in_features attribute. You could flatten x1 after the cnn forward pass, print out its shape, and use that to define the layer in your model definition.
# Define the in_features of the first linear layer with a random number,
# as we want to provoke a shape-mismatch error and read off the real shapes
def forward(self, image, data):
    x1 = self.cnn(image)
    x1 = x1.view(x1.size(0), -1)  # flatten the cnn output
    print(x1.shape)
    x2 = data
    print(x2.shape)
    x = torch.cat((x1, x2), dim=1)
    x = self.fc(x)  # will throw an error, but we got the shapes!
    return x

Use a single forward pass with your input tensors, look at the printed shapes, and define the in_features accordingly.
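For reference, resnet18’s penultimate feature size is 512, so with e.g. 3 additional features the first linear layer would need in_features = 512 + 3 = 515. A sketch of the complete model under that assumption:

import torch
import torch.nn as nn
from torchvision import models

class MyModel(nn.Module):
    def __init__(self, num_extra_features=3):
        super(MyModel, self).__init__()
        self.cnn = models.resnet18(pretrained=True)
        self.cnn.fc = nn.Identity()  # expose the 512-dim penultimate activations
        self.fc = nn.Sequential(
            nn.Linear(512 + num_extra_features, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 2),
            nn.LogSoftmax(dim=1))

    def forward(self, image, data):
        x1 = self.cnn(image)            # [N, 512]
        x2 = data                       # [N, num_extra_features]
        x = torch.cat((x1, x2), dim=1)  # fuse at the feature level
        x = self.fc(x)
        return x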

Ah gotcha! I will try this out tomorrow when I am near the GPU machine.

After this project I really need to learn PyTorch from the ground up, since I will be working on a more heavy-duty vision project than the current one. All these lessons will come in very handy then!