Transfer Learning - Long Training Times || Loss not converging - GoogLeNet

Hi Community and thanks in advance for the help.

I am working on transfer learning, specifically the GoogLeNet model with the Food101 dataset. I think everything is in order from data preprocessing through to training, but the training times are incredibly long and the loss doesn't seem to be converging. I was running it on CPU initially to make sure everything was functioning before transferring to GPU. Code is below.
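
For reference, this is roughly how I plan to move things across once it runs cleanly on CPU; a minimal sketch, assuming a CUDA-capable GPU is available:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
google_net = google_net.to(device)
## inside the training loop each batch would then need moving too:
## img, label = img.to(device), label.to(device)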

Model:

import torch
from torch import nn
import torchvision
from torchinfo import summary

## assumed load step so the snippet is self-contained: standard pretrained GoogLeNet
google_net = torchvision.models.googlenet(weights=torchvision.models.GoogLeNet_Weights.DEFAULT)

for params in google_net.parameters():
    params.requires_grad = False ## freeze all pretrained layers

## Change the number of output params in the fc layer
num_classes = 101 ## output classes
google_net.fc = nn.Linear(1024, num_classes)

# Check Model Summary
model_summary = summary(google_net, ## renamed so it doesn't shadow the built-in sum()
        input_size=(128, 3, 256, 256), # make sure this is (batch_size, color_channels, height, width)
        verbose=0,
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
)
print(model_summary)

The only trainable layer in the summary above is the .fc layer at the end. Output below (with a quick cross-check after it):

Total params: 5,703,429
Trainable params: 103,525 ## .fc layer only
Non-trainable params: 5,599,904
Total mult-adds (G): 47.92

Input size (MB): 19.27
Forward/backward pass size (MB): 1651.82
Params size (MB): 22.81
Estimated Total Size (MB): 1693.90
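
As a quick cross-check on those numbers without torchinfo, a minimal sketch:

total = sum(p.numel() for p in google_net.parameters())
trainable = sum(p.numel() for p in google_net.parameters() if p.requires_grad)
print(f'total: {total}, trainable: {trainable}') ## trainable should be 1024*101 + 101 = 103,525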

Data PreProcess:

import torchvision
from torchvision import transforms

T_form_gog = transforms.Compose([transforms.CenterCrop(224),
                                transforms.Resize((256,256)),
                                transforms.ToTensor(), #standardise to 0-1
                                transforms.Normalize(
                                    mean = [0.485, 0.456, 0.406],
                                    std = [0.229, 0.224, 0.225])]) ## Taken from pytorch googlenet documentation 
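
One thing I wasn't sure about: the torchvision GoogLeNet weights documentation resizes first and then center-crops to 224, i.e. the reverse of my order, which would also hand the model 224x224 inputs rather than 256x256. A sketch of that documented order, in case it matters:

T_form_docs = transforms.Compose([transforms.Resize(256),
                                  transforms.CenterCrop(224),
                                  transforms.ToTensor(),
                                  transforms.Normalize(
                                      mean = [0.485, 0.456, 0.406],
                                      std = [0.229, 0.224, 0.225])])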

trg_101_gog = torchvision.datasets.Food101(root = './data',
                                          split = 'train',
                                          transform = T_form_gog,
                                          download = True)

val_101_gog = torchvision.datasets.Food101(root = './data',
                                          split = 'test',
                                          transform = T_form_gog,
                                          download = True)
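
With epochs taking over an hour on CPU, a small slice is handy for smoke tests; a minimal sketch using torch.utils.data.Subset (the 1,024-image size is an arbitrary pick):

from torch.utils.data import Subset
trg_101_small = Subset(trg_101_gog, range(1024)) ## hypothetical smoke-test subset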

Data Loaders:

from torch.utils.data import DataLoader
Batch_size = 128

trg_load_gog = DataLoader(trg_101_gog,
                          batch_size = Batch_size,
                          shuffle = True)
val_load_gog = DataLoader(val_101_gog,
                          batch_size = Batch_size,
                          shuffle = True)
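
On the speed side, I understand DataLoader loads batches in the main process by default; this is what I was planning to try for faster input (the worker count is a guess for my machine):

trg_load_gog = DataLoader(trg_101_gog,
                          batch_size = Batch_size,
                          shuffle = True,
                          num_workers = 4,   ## parallel loading workers; tune to CPU cores
                          pin_memory = True) ## speeds host-to-GPU copies once on CUDA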

Params:

from torch import nn 
loss_fn = nn.CrossEntropyLoss()
optim = torch.optim.Adam(google_net.parameters(), lr = 0.001)
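
Since most of the network is frozen, passing only the trainable parameters to Adam seems the tidier form (it should be equivalent, as frozen params never receive grads); a sketch:

optim = torch.optim.Adam((p for p in google_net.parameters() if p.requires_grad),
                         lr = 0.001)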

Train Loop:

from datetime import datetime as DT
NUM_EPOCHS = 30
train_losses = []

for epoch in range(NUM_EPOCHS):
    train_loss = 0
    tic = DT.now()
    for batch, (img, label) in enumerate(trg_load_gog):

        ## zero the grads
        optim.zero_grad()

        ## forward 
        output = google_net(img)

        ## calc loss: CrossEntropyLoss takes raw logits and integer class labels directly
        loss = loss_fn(output, label)

        ## prop losses
        loss.backward()

        ## update weights
        optim.step()

        ## update loss 
        train_loss += loss.item()
    toc = DT.now()
    r_time = toc-tic
    train_losses.append(train_loss)
    print(f'epoch: {epoch}, train loss: {train_loss}, start time: {tic}, run time:{r_time}')     

output is:

epoch: 0, train loss: 1524.6963007450104, run time:1:02:26.554898
epoch: 1, train loss: 1454.7811387777328, run time:1:07:53.805809
epoch: 2, train loss: 1419.673243880272, run time:1:07:17.671551
epoch: 3, train loss: 1395.2064596414566, run time:1:07:22.943297
epoch: 4, train loss: 1380.3496369123459, run time:1:07:15.002554
epoch: 5, train loss: 1365.1366027593613, run time:1:07:17.721393
Paused…
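
I haven't wired up the validation pass yet (val_load_gog is defined but unused above); a minimal sketch of what I intend to add at the end of each epoch:

google_net.eval() ## eval mode for correct inference behaviour
val_loss = 0
with torch.no_grad(): ## no gradients needed for evaluation
    for img, label in val_load_gog:
        val_loss += loss_fn(google_net(img), label).item()
google_net.train() ## back to train mode for the next epoch
print(f'epoch: {epoch}, val loss: {val_loss}')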

My initial thoughts are on the .fc layer. Should I add other linear layers to make this more efficient/effective, something like the sketch below? I was trying to stay as close to the original architecture as possible.
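
For concreteness, this is the kind of replacement head I had in mind (the hidden width and dropout rate are arbitrary picks, not tuned):

google_net.fc = nn.Sequential(
    nn.Linear(1024, 512), ## hypothetical hidden width
    nn.ReLU(),
    nn.Dropout(p=0.2),    ## hypothetical dropout rate
    nn.Linear(512, num_classes))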

Any advice and/or assistance would be greatly appreciated.

Thanks again