Hey, I’ve been trying to train my custom squeeze-net implementation. For some reason its accuracy is really poor (like 0.014, should ideally be greater than 0.7 or 0.8). I implemented a fire module from the paper as follows:
class FireModule(nn.Module):
def __init__(self, in_channels, s1x1, e1x1, e3x3):
super(FireModule, self).__init__()
self.squeeze = nn.Conv2d(in_channels=in_channels, out_channels=s1x1, kernel_size=1, stride=1)
self.expand1x1 = nn.Conv2d(in_channels=s1x1, out_channels=e1x1, kernel_size=1)
self.expand3x3 = nn.Conv2d(in_channels=s1x1, out_channels=e3x3, kernel_size=3, padding=1)
def forward(self, x):
x = F.relu(self.squeeze(x))
x1 = self.expand1x1(x)
x2 = self.expand3x3(x)
x = F.relu(torch.cat((x1, x2), dim=1))
return x
And implemented SqueezeNet using the above module:
class SqueezeNet(nn.Module):
def __init__(self, out_channels):
super(SqueezeNet, self).__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=96, kernel_size=7, stride=2)
self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2)
self.fire2 = FireModule(in_channels=96, s1x1=16, e1x1=64, e3x3=64)
self.fire3 = FireModule(in_channels=128, s1x1=16, e1x1=64, e3x3=64)
self.fire4 = FireModule(in_channels=128, s1x1=32, e1x1=128, e3x3=128)
self.fire5 = FireModule(in_channels=256, s1x1=32, e1x1=128, e3x3=128)
self.fire6 = FireModule(in_channels=256, s1x1=48, e1x1=192, e3x3=192)
self.fire7 = FireModule(in_channels=384, s1x1=48, e1x1=192, e3x3=192)
self.fire8 = FireModule(in_channels=384, s1x1=64, e1x1=256, e3x3=256)
self.fire9 = FireModule(in_channels=512, s1x1=64, e1x1=256, e3x3=256)
self.dropout = nn.Dropout(p=0.5)
self.conv10 = nn.Conv2d(in_channels=512, out_channels=out_channels, kernel_size=1, stride=1)
self.avg_pool = nn.AvgPool2d(kernel_size=12, stride=1)
# We don't have FC Layers, inspired by NiN architecture.
def forward(self, x):
# First max pool after conv1
x = self.max_pool(self.conv1(x))
# Second max pool after fire4
x = self.max_pool(self.fire4(self.fire3(self.fire2(x))))
# Third max pool after fire8
x = self.max_pool(self.fire8(self.fire7(self.fire6(self.fire5(x)))))
# Final pool (avg in this case) after conv10
x = self.avg_pool(self.conv10(self.fire9(x)))
return torch.flatten(x, start_dim=1)
I tried applying relu right after conv1 (also tried it after conv10 as well) and tuning my learning rate a bit. But neither of those seem to change anything. The following is my training loop. (FYI, i did train a custom mobile-net implementation successfully with the same dataset I am using here, so its unlikely that something is wrong with the way I load my custom dataset). The following is my training loop for your reference:
if __name__ == '__main__':
LEARNING_RATE = 0.001
BATCH_SIZE = 64
EPOCHS = 10
NUM_CLASSES = 131
device = "cuda" if torch.cuda.is_available() else "cpu"
transform_img = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize((224, 224)),
transforms.ToTensor()
])
train_val_dict, labels, classes = generate_dictionaries()
train_data = FruitsDataset(list_ids=train_val_dict['train'],
labels=labels,
idx2class=classes,
root_dir="../input/fruits/fruits-360/Training",
transforms=transform_img)
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
val_data = FruitsDataset(list_ids=train_val_dict['val'],
labels=labels,
idx2class=classes,
root_dir="../input/fruits/fruits-360/Training",
transforms=transform_img)
val_loader = DataLoader(dataset=val_data, batch_size=BATCH_SIZE, shuffle=True)
squeezenet = SqueezeNet(NUM_CLASSES)
squeezenet.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(squeezenet.parameters(), lr=LEARNING_RATE)
# data, targets = next(iter(train_loader))
for epoch in tqdm(range(EPOCHS)):
losses = []
with tqdm(total=len(train_val_dict['train']) // BATCH_SIZE) as pbar:
for batch_idx, (data, targets) in enumerate(train_loader):
data = data.to(device=device)
targets = targets.to(device=device)
scores = squeezenet(data)
loss = criterion(scores, targets)
losses.append(loss)
# backprop
optimizer.zero_grad()
loss.backward()
optimizer.step()
# print(loss.item())
pbar.update(1)
print("Cost at epoch {} is {}".format(epoch, sum(losses) / len(losses)))
print("Calculating Validation Accuracy...")
check_accuracy(val_loader, squeezenet)
print("Calculating Train Accuracy...")
check_accuracy(train_loader, squeezenet)