Hello,
I am fairly new to PyTorch and have been searching Google for help, but nothing I found matches my situation. I want to train a network on CIFAR10 that reuses the vgg11 'features' block together with my own classifier head. I will post my whole code step by step, which will hopefully make it easier to spot the issue.
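For reference, the snippets below assume the usual imports and device setup (I am not copying my exact notebook cells here, so take this part as standard boilerplate):

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")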
This is what the training function looks like:
def train(network, data, loss, opt):
    total = 0
    correct = 0
    errors_up = []
    network.train()
    for inputs, labels in data:
        inputs, labels = inputs.to(device), labels.to(device)
        opt.zero_grad()
        output = network(inputs)
        losses = loss(output, labels)
        losses.backward()
        opt.step()
        errors_up.append(losses.item())
        _, predicted = torch.max(output, 1)
        total += labels.size(0)
        correct += np.sum(predicted[1].cpu().numpy() == labels.cpu().numpy())
    accuracy = 100. * correct / total
    print("Train Accuracy: {:.4f}".format(accuracy), f'[{correct}/{total}]')
    print('Cumulative error sum: ', sum(errors_up))
    return errors_up
Network:
class VGG(nn.Module):
    def __init__(self, features: nn.Module, num_classes: int = 10):
        super(VGG, self).__init__()
        self.features = features
        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=2)
        self.fc = nn.Sequential(
            nn.Linear(512 * 2 * 2, 512),
            nn.ReLU(True),
            nn.Dropout(0.4),
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Dropout(0.4),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
Next, I extract the 'features' block from torchvision's vgg11 and pass it into my network:
from torchvision.models import vgg
VGG = vgg.vgg11(pretrained=False)
feats = VGG.children()
features = nn.Sequential(*list(feats)[0])
net = VGG(features=features).to(device)
net
Now I load the pretrained weights from the VGG11 state_dict into my 'features' layer:
def cifar_vgg11():
    model = net
    pretrained_dict = torch.load('model/vgg11-bbd30ac9.pth')
    model_dict = model.state_dict()
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    return model
Then I load the CIFAR10 training data and use torch.utils.data.sampler.RandomSampler to train on a small part of the dataset (this is what happens inside my cifar10_subset_loader helper, which is called below):
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))])
data_train = datasets.CIFAR10(root=root, download=True, transform=transform, train=True)
sampler_train = torch.utils.data.sampler.RandomSampler(data_train, num_samples=num_samples, replacement=True)
dataload_train = torch.utils.data.DataLoader(dataset=data_train, batch_size=batch_size, num_workers=num_workers, sampler=sampler_train)
Finally, I initialize the network, freeze the weights of the 'features' portion, and train for a small number of epochs:
network = cifar_vgg11()
for param in network.features.parameters():
    param.requires_grad == False
network.to(device)
data = cifar10_subset_loader("CIFAR10")
opt = optim.Adam(network.parameters(), lr=0.0005)
loss = nn.CrossEntropyLoss()
epochs = 10
for epoch in range(0, epochs):
    train(network, data, loss, opt)
And I get approximately the same loss and accuracy values every time:
Train Accuracy: 9.3750 [48/512]
Cumulative error sum: 32.791540026664734
Train Accuracy: 11.1328 [57/512]
Cumulative error sum: 29.287562489509583
Train Accuracy: 10.9375 [56/512]
Cumulative error sum: 24.683462738990784
Train Accuracy: 8.7891 [45/512]
Cumulative error sum: 22.476744771003723
Train Accuracy: 12.8906 [66/512]
Cumulative error sum: 19.47743135690689
Train Accuracy: 9.7656 [50/512]
Cumulative error sum: 18.563248455524445
Train Accuracy: 8.3984 [43/512]
Cumulative error sum: 20.311830699443817
Train Accuracy: 12.8906 [66/512]
Cumulative error sum: 19.25118178129196
Train Accuracy: 11.1328 [57/512]
Cumulative error sum: 15.79049265384674
Train Accuracy: 13.2812 [68/512]
...
I have already tried different learning rates and different optimizers, as well as different loss combinations (raw logits + CrossEntropyLoss, and log_softmax + NLLLoss). I have also tried training on the whole training set, which didn't help either. If you see any issue in my code, I would really appreciate your help in figuring it out; I have been stuck on this for quite a while.
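P.S. To be clear about the second loss combination: by log_softmax + NLLLoss I mean the standard equivalent of CrossEntropyLoss, roughly like this (a sketch, not my exact training code):

import torch.nn.functional as F

# raw logits + CrossEntropyLoss replaced by log-probabilities + NLLLoss
loss = nn.NLLLoss()
output = F.log_softmax(network(inputs), dim=1)
losses = loss(output, labels)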
Regards,
Bollo7