I have been working with an unsupervised model that learns from the NORB dataset (images of toys). I wanted to create a version that learns from samples of variations of a single toy, in order to figure out the angle variations, instead of using all toys across all classes as I was doing before for traditional classification.
The problem is that exactly the same code, restricted to one toy, gives a loss that stays constant to several decimal places for any learning rate below roughly 1e-2, and above that the loss becomes extremely large. It was working fine until I reduced the dataset to a single toy.
I have tried many learning rates, adding and removing layers, and using a different toy.
My loss looks like this:
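import numpy as np
import torch

def infoNCE(zt, ztk):
    # b_size is a global defined elsewhere (the batch size)
    ind = np.diag_indices(b_size)
    # Similarity matrix: entry (i, j) scores ztk[i] against zt[j]
    mul = ztk @ torch.t(zt)
    # Positive pairs sit on the diagonal
    num = mul[ind[0], ind[1]]
    # Numerically stable log-sum-exp over each column
    m, _ = torch.max(mul, axis=0)
    den = torch.log(torch.sum(torch.exp(mul - m), axis=0)) + m
    val = torch.mean(-num + den)
    return val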
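As a cross-check, the loss above appears to be equivalent to a standard cross-entropy over the similarity matrix with the diagonal as targets. Here is a minimal sketch of that form, assuming the positives are the matching rows of zt and ztk. The names info_nce_ce and info_nce_manual are hypothetical, and both read the batch size from the tensors rather than from the global b_size.

```python
import torch
import torch.nn.functional as F

def info_nce_ce(zt, ztk):
    # Hypothetical helper, not part of the original code.
    # logits[j, i] scores zt[j] against ztk[i], i.e. the transpose of mul.
    logits = zt @ ztk.t()
    # The positive for row j is column j (the diagonal).
    targets = torch.arange(zt.size(0), device=zt.device)
    # cross_entropy computes mean(-logits[j, j] + logsumexp(logits[j, :])).
    return F.cross_entropy(logits, targets)

def info_nce_manual(zt, ztk):
    # Same computation as infoNCE in the post, without the global b_size.
    mul = ztk @ zt.t()
    num = torch.diagonal(mul)
    den = torch.logsumexp(mul, dim=0)
    return torch.mean(-num + den)
```

If the two functions agree on random inputs, the hand-written loss is at least computing what I think it is, so the problem is more likely in the data or the embeddings than in the formula.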
My model:
import torch.nn as nn
import torch.nn.functional as F

class Module1(nn.Module):
    def __init__(self):
        super(Module1, self).__init__()
        # 5x5 convolutions, stride 1, padding 2 (spatial size preserved)
        self.conv1 = nn.Conv2d(1, 32, 5, 1, 2)
        self.conv2 = nn.Conv2d(32, 32, 5, 1, 2)
        self.conv3 = nn.Conv2d(32, 64, 5, 1, 2)
        self.conv4 = nn.Conv2d(64, 64, 5, 1, 2)
        self.conv5 = nn.Conv2d(64, 128, 5, 1, 2)
        self.conv6 = nn.Conv2d(128, 128, 5, 1, 2)
        self.pool = nn.MaxPool2d(2, 2)
        # 128*8*8 implies 64x64 inputs (three 2x2 poolings: 64 -> 8)
        self.fc = nn.Linear(128*8*8, 3)

    def forward(self, x):
        x = self.pool(F.relu(self.conv2(F.relu(self.conv1(x)))))
        x = self.pool(F.relu(self.conv4(F.relu(self.conv3(x)))))
        x = self.pool(F.relu(self.conv6(F.relu(self.conv5(x)))))
        # Flatten to (batch, features) before the linear layer
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
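The fc layer's input size of 128*8*8 implies a particular input resolution. A quick sketch of the arithmetic, assuming square inputs and the usual convolution output-size formula; the helper names conv_out, pool_out and feature_side are hypothetical:

```python
def conv_out(size, kernel=5, stride=1, padding=2):
    # Standard output-size formula for a convolution along one dimension.
    return (size + 2 * padding - kernel) // stride + 1

def pool_out(size, kernel=2, stride=2):
    # Max pooling with no padding.
    return (size - kernel) // stride + 1

def feature_side(size):
    # Three stages of conv -> conv -> pool, as in Module1.forward.
    for _ in range(3):
        size = pool_out(conv_out(conv_out(size)))
    return size
```

With these defaults, feature_side(64) gives 8, which matches 128*8*8. A 96x96 input would give 12, and the fc layer would then raise a shape mismatch.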
My training routine:
import torch.optim as optim
from tqdm import trange

# device, lr, epochs and trainloader are defined elsewhere
mod1 = Module1().to(device)
optimizer = optim.Adam(mod1.parameters(), lr=lr)
for epoch in trange(epochs):
    for i, data in enumerate(trainloader):
        # data[0] holds the pair (xt, xtk)
        xt = data[0][0].to(device)
        xtk = data[0][1].to(device)
        optimizer.zero_grad()
        zt = mod1(xt)
        ztk = mod1(xtk)
        loss = infoNCE(zt, ztk)
        loss.backward()
        optimizer.step()
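To see what is happening inside the loop, two quantities that could be logged are how much the embeddings vary across a batch and how large the gradients are. This is only a sketch of such diagnostics; embedding_spread and grad_norm are hypothetical helpers, not part of my code:

```python
import torch

def embedding_spread(z):
    # Mean per-dimension standard deviation across the batch.
    # A value near zero suggests every sample maps to nearly the same point.
    return z.std(dim=0).mean().item()

def grad_norm(model):
    # L2 norm over all parameter gradients; call after loss.backward().
    total = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total += p.grad.pow(2).sum().item()
    return total ** 0.5
```

Inside the loop, embedding_spread(zt) could be printed after the forward pass and grad_norm(mod1) after loss.backward().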
The loss is always around 2.77, because that is usually the value of the variable den, while num is kinda small (this didn’t happen when I used all of the samples).
EDIT: Yes, this number is suspiciously close to e, and I’m using logarithms.
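Another value that 2.77 may be worth comparing against is log(b_size). When every similarity in mul is the same, num cancels and the loss reduces to the log of the batch size. A small NumPy sketch, assuming a batch size of 16 (not stated above; log(16) is about 2.7726, whereas e is about 2.718). info_nce_np is a hypothetical mirror of the loss, written only for this check:

```python
import math
import numpy as np

def info_nce_np(zt, ztk):
    # NumPy mirror of the loss in the post, for a quick sanity check only.
    mul = ztk @ zt.T
    num = np.diag(mul)
    m = mul.max(axis=0)
    den = np.log(np.exp(mul - m).sum(axis=0)) + m
    return float(np.mean(-num + den))

# Assumption: a batch of 16, which the post does not state.
n = 16
z = np.ones((n, 3))             # every sample maps to the same embedding
collapsed_loss = info_nce_np(z, z)
chance_level = math.log(n)      # loss when all similarities are equal
```

If the real batch size is 16, a loss stuck at this value would be consistent with the model giving every pair the same score.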