So, I have been working with an unsupervised model trained on the NORB dataset (images of toys). I wanted to create a version that learns from variations of just one toy, in order to figure out the angle variations, instead of using all toys from all classes as I was doing before for traditional classification.
The thing is, exactly the same code restricted to a single toy gives me a loss that stays constant to several decimal places for any learning rate below roughly 1e-2, and above that the loss is just crazy high. Everything was working fine until I reduced the dataset to a single toy.
I have tried many learning rates, adding and removing layers, using a different toy…
My loss looks like this:

    import torch

    def infoNCE(zt, ztk):
        # Pairwise similarity logits: mul[i, j] = ztk[i] . zt[j]
        mul = ztk @ zt.t()
        # Positive pairs sit on the diagonal
        num = torch.diagonal(mul)
        # Numerically stable log-sum-exp over each column
        den = torch.logsumexp(mul, dim=0)
        return torch.mean(-num + den)

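Written out, the function above computes the following. The symbols $S$ (the `mul` matrix) and $B$ (the batch size) are notation introduced here for illustration, not names from the code.

```latex
S_{ij} = \mathrm{ztk}_i \cdot \mathrm{zt}_j, \qquad
\mathcal{L} = \frac{1}{B} \sum_{j=1}^{B}
  \left[ -S_{jj} + \log \sum_{i=1}^{B} \exp\left(S_{ij}\right) \right]
```

The first term is `num` (the diagonal) and the second is `den` (the column-wise log-sum-exp).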
My model:

    import torch.nn as nn
    import torch.nn.functional as F

    class Module1(nn.Module):
        def __init__(self):
            super().__init__()
            # Three pairs of 5x5 convolutions; stride 1 and padding 2 keep H and W
            self.conv1 = nn.Conv2d(1, 32, 5, 1, 2)
            self.conv2 = nn.Conv2d(32, 32, 5, 1, 2)
            self.conv3 = nn.Conv2d(32, 64, 5, 1, 2)
            self.conv4 = nn.Conv2d(64, 64, 5, 1, 2)
            self.conv5 = nn.Conv2d(64, 128, 5, 1, 2)
            self.conv6 = nn.Conv2d(128, 128, 5, 1, 2)
            self.pool = nn.MaxPool2d(2, 2)
            # 128*8*8 inputs after three 2x poolings implies 64x64 images
            self.fc = nn.Linear(128 * 8 * 8, 3)

        def forward(self, x):
            x = self.pool(F.relu(self.conv2(F.relu(self.conv1(x)))))
            x = self.pool(F.relu(self.conv4(F.relu(self.conv3(x)))))
            x = self.pool(F.relu(self.conv6(F.relu(self.conv5(x)))))
            x = x.view(x.size(0), -1)  # flatten to (batch, 128*8*8)
            x = self.fc(x)
            return x

My training routine:
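
    import torch.optim as optim
    from tqdm import trange

    # device, lr, epochs and trainloader are defined elsewhere
    mod1 = Module1().to(device)
    optimizer = optim.Adam(mod1.parameters(), lr=lr)

    for epoch in trange(epochs):
        for i, data in enumerate(trainloader):
            # As written, xt and xtk are both the same batch
            xt = data.to(device)
            xtk = data.to(device)

            optimizer.zero_grad()
            zt = mod1(xt)
            ztk = mod1(xtk)
            loss = infoNCE(zt, ztk)
            loss.backward()
            optimizer.step()
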
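One way to check whether a constant loss comes from the encoder mapping every image to (nearly) the same point is to log how spread out the embeddings are within a batch. The following is only a minimal sketch: `embedding_spread` is a hypothetical helper that is not part of the original code, and the tensors are made-up stand-ins rather than NORB data.

```python
import torch

def embedding_spread(z: torch.Tensor) -> float:
    """Mean per-dimension standard deviation across a batch of embeddings."""
    return z.detach().float().std(dim=0).mean().item()

torch.manual_seed(0)

# Made-up stand-ins for a batch of 3-dimensional embeddings
collapsed = torch.full((16, 3), 0.5)  # every sample maps to the same point
varied = torch.randn(16, 3)           # samples are spread out

print("collapsed:", embedding_spread(collapsed))  # zero spread
print("varied:   ", embedding_spread(varied))     # clearly above zero
```

Calling something like `embedding_spread(zt)` every few iterations of the training loop would show whether the spread shrinks toward zero when only one toy is used.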
The loss is always around 2.77, because that is usually the value of the variable `den`, while `num` is kinda small (this didn’t happen when I used all of the samples).
EDIT: Yes, this number is suspiciously close to `e`, and I’m using logarithms.
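For what it is worth, 2.77 is also very close to ln 16 ≈ 2.7726. When every entry of `mul` has the same value c, `-num + den` reduces to -c + log(B·exp(c)) = log(B), where B is the batch size. The sketch below reproduces this under the assumption of a batch size of 16, which is not stated in the post; `info_nce` is a compact restatement of the loss above, and the constant embeddings are made up for illustration.

```python
import math
import torch

def info_nce(zt, ztk):
    # Same computation as the infoNCE loss in the post
    mul = ztk @ zt.t()
    return torch.mean(-torch.diagonal(mul) + torch.logsumexp(mul, dim=0))

batch_size = 16                        # assumption: not stated in the post
z = torch.full((batch_size, 3), 0.1)   # identical embedding for every sample

loss = info_nce(z, z).item()
print(loss, math.log(batch_size))      # both approximately 2.7726
```

If the batch size really is 16, a loss pinned at this value would be consistent with the model producing (almost) the same output for every image, so that the positives on the diagonal are indistinguishable from the negatives.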