Transfer Learning Only Learns Last Class

Hi,

I’m trying to train a new classifier on top of the pretrained VGG16 feature module, with requires_grad=False for all feature layers. Even though the loss converges for each of my 5 classes during training, my validation images always score highest for the last class the network was trained on.

I am wondering if this is because of the way that my training is set up: I have 5 folders, one for each class. The training loop opens each folder, then feeds each of the images in the folder into a Variable for the forward pass, computes loss, does backprop, etc. for 2 epochs. It then moves on to the folder of the next class and does this all again. So something like:

root_directory -> [class1, class2, class3, class4, class5] -> class1[image_0, image_1,…image_n] (iterate over all of these twice) -> class2…

This is the classifier that I’m using:

self.classifier = nn.Sequential(
    nn.Dropout(inplace=False),
    nn.Linear(32768, 500),  # for 256x256 images
    nn.ReLU(inplace=False),
    nn.Dropout(inplace=False),
    nn.Linear(500, 100),
    nn.ReLU(inplace=False),
    nn.Linear(100, 5)
)
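
For reference, the 32768 in the first nn.Linear can be checked by hand. This is only a sketch: it assumes the standard VGG16 feature layout (padded 3x3 convolutions, five 2x2 max-pools, 512 output channels) with nothing else between the features and the classifier, and the helper name is made up for illustration.

```python
def vgg16_flat_features(height, width, channels=512, num_pools=5):
    """Flattened size of the VGG16 feature output for a given input size.

    Assumes the standard VGG16 layout: padded 3x3 convs keep the spatial
    size, and each of the five 2x2 max-pools halves it (floor division).
    """
    for _ in range(num_pools):
        height //= 2
        width //= 2
    return channels * height * width


print(vgg16_flat_features(256, 256))  # 512 * 8 * 8 = 32768
print(vgg16_flat_features(224, 224))  # 512 * 7 * 7 = 25088
```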

Attached image is the graph of loss for the whole training cycle.

Thanks for any help!

Edit:

So I moved on to trying to use Resnet18 as the basis for this same transfer learning problem. I was thinking that my classifier might have been the problem, so I tried to follow the transfer learning tutorial as closely as possible for my own dataset. This is the training function that I am using:

def train_network(net, img_path_list, target_dict, criterion, optimizer, loss_graph, test_dirs=None, epochs=2, print_frequency=20, test_print=False):
    """train a network on a directory of images"""

    trained_net = net

    s = 0  # generic counter for doing stuff
    for path in img_path_list:

        target = target_gen(target_dict, path)
        goal = Variable(target).cuda()
        print(path)
        min_loss = 1000  # absurdly large loss

        for i in range(epochs):
            for filename in os.listdir(path):

                input = image_loader(path + '/' + filename)

                optimizer.zero_grad()

                output = net(input)

                if s % print_frequency == 0 and test_print is True:
                    print('\n')
                    print('output: ', output.data)
                    print('goal: ', goal.data)
                    print('\n')

                loss = criterion(output, goal)
                loss.backward()
                optimizer.step()

                loss_graph += [loss.data[0]]
                if s % print_frequency == 0:
                    print('epoch: {} '.format(i + 1), 'loss: {}'.format(loss.data[0]))

                if loss.data[0] < min_loss:
                    min_loss = loss.data[0]
                    print('new min loss: {}'.format(min_loss))
                    trained_net = copy.deepcopy(net)

                s += 1


    print('done!')
    return trained_net

But it still has the same problem: no matter at what point I test the network on my validation set, it only ever gives the highest probability to the last class it was trained on. This is in spite of its outputs matching the target closely during training, e.g.

output:  
-0.0148  0.8499  0.0970 -0.0274  0.0879
[torch.cuda.FloatTensor of size 1x5 (GPU 0)]
goal:  
 0
 1
 0
 0
 0
[torch.cuda.FloatTensor of size 5 (GPU 0)]

Do the input and target variables need to be in a minibatch format that contains examples of multiple classes so that all correlations are affected at every forward pass? I cannot figure out what is going on…

Hello,

I would expect your results to improve if you shuffle your training data across classes.
To do this elegantly it is likely a good idea to define a torch.utils.data.Dataset for your dataset and then use a DataLoader (with shuffle=True).
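
A minimal sketch of what that could look like, assuming random tensors as stand-ins for the real images and integer class indices as labels. The class name FolderPerClassDataset and all sizes are invented for illustration; a real version would load files from the five class folders in __getitem__.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class FolderPerClassDataset(Dataset):
    """Hypothetical dataset that indexes samples from all classes together."""

    def __init__(self, samples_per_class=4, num_classes=5):
        # Stand-in data: random 3x8x8 "images" instead of files read from disk.
        self.images = torch.randn(samples_per_class * num_classes, 3, 8, 8)
        # One integer class index per sample: 0,0,0,0,1,1,1,1,...
        self.labels = torch.arange(num_classes).repeat_interleave(samples_per_class)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]


dataset = FolderPerClassDataset()
# shuffle=True draws a new random order over the whole dataset every epoch,
# so a batch can contain examples from several classes.
loader = DataLoader(dataset, batch_size=4, shuffle=True)

for epoch in range(2):
    for images, labels in loader:
        # The forward pass, loss, backward and optimizer step would go here.
        print(epoch, labels.tolist())
```

With this setup the network sees all five classes interleaved throughout training rather than one folder at a time.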

One epoch is commonly considered to be a loop over the entire dataset (all classes), so that loop should be the outermost.
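
The loop order can be illustrated without any PyTorch at all. The sketch below uses a made-up file listing (three images per class, named after the folder layout in the question) and a hypothetical helper epoch_order; it only shows the iteration order, not the training step.

```python
import random

# Hypothetical listing: (path, class_index) pairs gathered from all five folders.
samples = [("class{}/image_{}".format(c + 1, i), c)
           for c in range(5) for i in range(3)]


def epoch_order(samples, epochs, seed=0):
    """Yield (epoch, path, label), reshuffling across classes every epoch."""
    rng = random.Random(seed)
    for epoch in range(epochs):      # outermost loop: one pass over everything
        order = list(samples)
        rng.shuffle(order)           # mix classes instead of going folder by folder
        for path, label in order:
            yield epoch, path, label


schedule = list(epoch_order(samples, epochs=2))
for epoch, path, label in schedule[:5]:
    print(epoch, path, label)
```

Compared with the training function in the question, the epoch loop and the per-class loop have swapped places, and every epoch visits each image exactly once.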

Best regards

Thomas

Thanks, Tom!

This seems to be helping. I had already started making a batching script, but the data_utils help a lot.