How can I feed a 2D intermediate feature, rather than a 3-channel 2D image, to Inception V3 for fine-tuning?

Here’s the model I am trying to use to fine-tune a pre-trained Inception V3 on a binary (2-class) classification problem. The inputs have shape (batch_size, number_of_patches, feature_dim), where the features are 512-dimensional ResNet-18 extractions; so with batch_size=64 and a large initial image of 6000 patches, I end up with an intermediate representation of shape (64, 6000, 512). I now want to fine-tune Inception V3 on my own labels and these intermediate representations. How can I do this? The current error, shown in full at the end of the post, is: RuntimeError: Expected 4-dimensional input for 4-dimensional weight [32, 3, 3, 3], but got 3-dimensional input of size [16, 3, 512] instead.

model_ft = models.inception_v3(pretrained=True)
# Freeze all the pre-trained weights; only the new heads below will train
for param in model_ft.parameters():
    param.requires_grad = False

# Parameters of newly constructed modules have requires_grad=True by default
# Handle the auxiliary net
num_ftrs = model_ft.AuxLogits.fc.in_features
model_ft.AuxLogits.fc = nn.Linear(num_ftrs, 2)
# Handle the primary net
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 2)
model = model_ft
model = nn.DataParallel(model)

if args.resume:
    print('load model{}'.format(args.resume))
    model.load_state_dict(torch.load(args.resume))

if torch.cuda.is_available():
    model = model.cuda()

and

# training, validation, and test phases

for epoch in range(num_epochs):

    train_loss = 0.
    total = 0.

    current_lr = optimizer.param_groups[0]['lr']
    print('\n=>Epoches %i, learning rate = %.7f, previous best = %.4f' %
          (epoch+1, current_lr, best_val_acc))
    
    train_epoch_loss = 0
    train_epoch_labels = []
    train_epoch_preds = []
    
    val_epoch_loss = 0
    val_epoch_labels = []
    val_epoch_preds = []
    val_epoch_labels_arr = []
    val_epoch_preds_arr = []
    

    epoch_loss = 0
    epoch_accuracy = 0
    
    if train:
        exp_lr_scheduler.step()  # note: since PyTorch 1.1, step() should follow the epoch's optimizer updates
        print('training...')
        torch.autograd.set_detect_anomaly(True)  # debugging aid; slows training noticeably
        for i_batch, sample_batched in enumerate(dataloader_train):  
            print(type(sample_batched))
            print(sample_batched.keys())
            # sample_batched['image'] is a list of tensors of length batch_size
            feats = torch.stack(sample_batched['image']) 
            print("feature size shape: ", feats.shape)
            labels = torch.as_tensor(sample_batched['label']).cuda() 
            output = model(feats)
            loss = kornia.losses.focal_loss(output, labels, **kwargs)
            print('train loss is: ', loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            acc = (output.argmax(dim=1) == labels).float().mean()
            train_preds = output.argmax(dim=1)
            print('train preds are: ', train_preds)
            train_epoch_preds.extend(train_preds.cpu().numpy())
            train_epoch_labels.extend(labels.cpu().numpy())
            epoch_accuracy += acc.item() / len(dataloader_train)
            epoch_loss += loss.item() / len(dataloader_train)  # .item() avoids retaining the autograd graph
            print('epoch accuracy: ', epoch_accuracy)
            
        train_epoch_accuracy = accuracy_score(train_epoch_labels, train_epoch_preds)
        print('train_epoch_accuracy: ', train_epoch_accuracy)

    if not test:
        print('not test')

    with torch.no_grad():
        print('Evaluating...')
        print('epoch is: ', epoch)
        
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        epoch_val_preds = []
        epoch_val_labels = []
        
        model.eval()  # needed so batchnorm/dropout (and the aux head) behave correctly during eval
        total = 0.
        batch_idx = 0
        val_preds = []
        val_labels = []
        predictions = []
        actuals = []
        
        
        for i_batch, sample_batched in enumerate(dataloader_val):
            feats = torch.stack(sample_batched['image']) 
            labels = torch.as_tensor(sample_batched['label']).cuda() 
            val_output = model(feats)
            val_loss = kornia.losses.focal_loss(val_output, labels, **kwargs)
            acc = (val_output.argmax(dim=1) == labels).float().mean()

Error is:

=>Epoches 1, learning rate = 0.0000000, previous best = 0.0000
training...
<class 'dict'>
dict_keys(['image', 'label', 'id'])
feature size shape:  torch.Size([64, 419, 512])

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [68], in <module>
     35 print("feature size shape: ", feats.shape)
     36 labels = torch.as_tensor(sample_batched['label']).cuda() 
---> 37 output = model(feats)
     38 loss = kornia.losses.focal_loss(output, labels, **kwargs)
     39 print('train loss is: ', loss)

File ~/research/venv/dpcc/lib/python3.8/site-packages/torch/nn/modules/module.py:1102, in Module._call_impl(self, *input, **kwargs)
   1098 # If we don't have any hooks, we want to skip the rest of the logic in
   1099 # this function, and just call forward.
   1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102     return forward_call(*input, **kwargs)
   1103 # Do not call functions when jit is used
   1104 full_backward_hooks, non_full_backward_hooks = [], []

File ~/research/venv/dpcc/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py:168, in DataParallel.forward(self, *inputs, **kwargs)
    166     return self.module(*inputs[0], **kwargs[0])
    167 replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
--> 168 outputs = self.parallel_apply(replicas, inputs, kwargs)
    169 return self.gather(outputs, self.output_device)

File ~/research/venv/dpcc/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py:178, in DataParallel.parallel_apply(self, replicas, inputs, kwargs)
    177 def parallel_apply(self, replicas, inputs, kwargs):
--> 178     return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

File ~/research/venv/dpcc/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py:86, in parallel_apply(modules, inputs, kwargs_tup, devices)
     84     output = results[i]
     85     if isinstance(output, ExceptionWrapper):
---> 86         output.reraise()
     87     outputs.append(output)
     88 return outputs

File ~/research/venv/dpcc/lib/python3.8/site-packages/torch/_utils.py:434, in ExceptionWrapper.reraise(self)
    430 except TypeError:
    431     # If the exception takes multiple arguments, don't try to
    432     # instantiate since we don't know how to
    433     raise RuntimeError(msg) from None
--> 434 raise exception

RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/jalal/research/venv/dpcc/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/home/jalal/research/venv/dpcc/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jalal/research/venv/dpcc/lib/python3.8/site-packages/torchvision/models/inception.py", line 200, in forward
    x, aux = self._forward(x)
  File "/home/jalal/research/venv/dpcc/lib/python3.8/site-packages/torchvision/models/inception.py", line 139, in _forward
    x = self.Conv2d_1a_3x3(x)
  File "/home/jalal/research/venv/dpcc/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jalal/research/venv/dpcc/lib/python3.8/site-packages/torchvision/models/inception.py", line 472, in forward
    x = self.conv(x)
  File "/home/jalal/research/venv/dpcc/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jalal/research/venv/dpcc/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 446, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/jalal/research/venv/dpcc/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 442, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Expected 4-dimensional input for 4-dimensional weight [32, 3, 3, 3], but got 3-dimensional input of size [16, 3, 512] instead

Hi Mona!

Regardless of whether you “fix” any specific technical errors that arise,
I don’t think that fine-tuning with your intermediate representation as
input to the pre-trained model will work.

The problem is that all of the layers of the model have been trained to
work together, and, as there is a lot of arbitrariness in the model weights,
the specific weights in, say, layer 4 have to “align” properly with the
specific weights in, say, layer 1. To say this another way, if you trained
the model twice from scratch (with different random initializations, etc.),
you couldn’t take layer 1’s weights from training run 1 and layer 4’s
weights from training run 2, and have your model work.
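
To make this concrete, here is a sketch of that kind of cross-run weight
mixing (both models are hypothetical stand-ins for your two training runs):

from torchvision import models

# Hypothetical stand-ins for two independent from-scratch training runs
run1 = models.inception_v3(pretrained=False)  # "training run 1"
run2 = models.inception_v3(pretrained=False)  # "training run 2"

# The state dicts match, so this runs without error, but the early layer
# copied from run2 was never trained together with run1's later layers,
# so the hybrid's weights are no longer mutually "aligned".
run1.Conv2d_1a_3x3.load_state_dict(run2.Conv2d_1a_3x3.state_dict())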

The first several layers of the inception model are convolutions that
expect 2d images, rather than your intermediate representation. So you
can’t feed your intermediate representation into these early layers. And
if you try to inject your intermediate representations into the model further
along, say before something we will call “layer 4,” you will have bypassed
layer 1 entirely. Your intermediate representation knows nothing about the
layer-1 weights, so they can’t be meaningful to the specific weights in
layer 4, whose meanings are dependent on the specific weights in the
“co-trained” layer 1.
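
You can check this directly (a minimal sketch; the batch size and patch
count below are just illustrative):

import torch
from torchvision import models

model = models.inception_v3(pretrained=True)
model.eval()

# The stem is an ordinary 2d convolution over 3-channel images:
print(model.Conv2d_1a_3x3.conv)
# Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)

with torch.no_grad():
    images = torch.randn(4, 3, 299, 299)  # (N, C, H, W): what the model expects
    print(model(images).shape)            # torch.Size([4, 1000])

    feats = torch.randn(4, 6000, 512)     # (batch, patches, 512): your representation
    # model(feats)  # raises the RuntimeError quoted in your traceback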

You could argue that having the weights in the layers you do keep be
“properly aligned” is an advantage, but my intuition is that starting with
those pre-trained weights won’t really be any better than starting with
a random initialization.

In any event, what you propose is very different from keeping the
architecture of a pre-trained model fixed (so that all the weights in all
the layers play nicely with one another) and then fine-tuning the model’s
weights on a different dataset (which is an established technique that
works quite well).
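
For comparison, that established workflow would look roughly like this,
with images going in rather than intermediate features (a sketch; the
normalization values are torchvision's standard ImageNet statistics):

import torch
import torch.nn as nn
from torchvision import models, transforms

# Keep the pre-trained architecture intact; replace only the two heads,
# exactly as in your setup code above.
model = models.inception_v3(pretrained=True)
for p in model.parameters():
    p.requires_grad = False
model.AuxLogits.fc = nn.Linear(model.AuxLogits.fc.in_features, 2)
model.fc = nn.Linear(model.fc.in_features, 2)

# Feed the model 3-channel images resized to 299 x 299, not feature tensors.
preprocess = transforms.Compose([
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Only the new heads have requires_grad=True, so only they get updated.
optimizer = torch.optim.SGD(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3, momentum=0.9)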

Best.

K. Frank