I tried your model using a fake dataset:
import torchvision

dataset = torchvision.datasets.FakeData(
    size=120,
    image_size=(3, 40, 40),
    num_classes=2,
    transform=transform)  # reuses your existing transform
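For reference, a minimal sketch of wrapping this dataset in a standard DataLoader for the overfitting run (the batch size is arbitrary here):

import torch

# Wrap the fake dataset for the overfitting run; batch_size is arbitrary.
loader = torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=True)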
and could overfit it in ~20 epochs if I use a weight initialization:
import torch.nn as nn

def weight_init(m):
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_uniform_(m.weight)  # xavier_uniform is deprecated in favor of xavier_uniform_
        nn.init.zeros_(m.bias)
Without the weight init it takes ~40 epochs to reach zero resubstitution (training) error.
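For completeness, a minimal sketch of how the init function would be applied; the Sequential stack is just a stand-in for your model:

import torch.nn as nn

model = nn.Sequential(  # stand-in architecture; use your own model here
    nn.Conv2d(3, 8, 3),
    nn.Flatten(),
    nn.Linear(8 * 38 * 38, 2))
model.apply(weight_init)  # applies weight_init recursively to every submodule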
If you have an imbalanced dataset, you could use WeightedRandomSampler.
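A rough sketch of how the sampler could be set up, reusing the dataset from above and assuming you have a tensor of integer class targets (the targets here are just random placeholders):

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

targets = torch.randint(0, 2, (120,))  # placeholder targets; use your real labels

# Weight each sample by the inverse frequency of its class.
class_counts = torch.bincount(targets)
sample_weights = (1.0 / class_counts.float())[targets]

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True)

# shuffle must stay False (the default) when a sampler is passed
loader = DataLoader(dataset, batch_size=10, sampler=sampler)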
Yes, it’s fine to print the probabilities via torch.exp,
as long as you don’t use the result to calculate the loss etc.
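A small sketch of what I mean, assuming the model returns log-probabilities (e.g. via F.log_softmax) and the loss is computed with nll_loss; the random tensors are just placeholders:

import torch
import torch.nn.functional as F

log_probs = F.log_softmax(torch.randn(4, 2), dim=1)  # stand-in for the model output
target = torch.randint(0, 2, (4,))

loss = F.nll_loss(log_probs, target)  # the loss uses the log-probabilities directly
probs = torch.exp(log_probs)          # only for printing / inspection
print(probs)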