Custom loss not decreasing after reducing the dataset

I have been working with an unsupervised model trained on the NORB dataset (images of toys). I now want a version that learns from samples of variations of just one toy, so that it can figure out the angle variations. Before, I used all toys for all classes, as in traditional classification.

The problem is this. With exactly the same code, training on only one toy leaves the loss essentially constant (identical to several decimal places) for any learning rate below roughly 1e-2. Above that, the loss becomes extremely high. Everything worked fine until I reduced the dataset to a single toy.

I have tried many learning rates, adding and removing layers, and using a different toy…

My loss looks like this:

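import torch

def infoNCE(zt, ztk):
  # mul[i, j] = ztk_i . zt_j; the positive pairs sit on the diagonal
  mul = ztk @ torch.t(zt)
  num = torch.diagonal(mul)  # works for any batch size, including a smaller last batch
  # numerically stable log-sum-exp over each column (dim 0)
  m, _ = torch.max(mul, dim=0)
  den = torch.log(torch.sum(torch.exp(mul - m), dim=0)) + m
  val = torch.mean(-num + den)
  return val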
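For reference, the function above computes the usual contrastive cross-entropy. For each column j of mul, the positive score mul[j, j] competes against every mul[i, j] in that column. The following minimal sketch, assuming PyTorch, checks this equivalence against `F.cross_entropy` with targets `0..B-1` on random embeddings. The batch size of 16 and embedding size of 3 are arbitrary choices for illustration.

```python
import torch
import torch.nn.functional as F


def info_nce_manual(zt, ztk):
    # Same computation as infoNCE, with the batch size taken from the input.
    mul = ztk @ zt.t()                   # mul[i, j] = ztk_i . zt_j
    num = torch.diagonal(mul)            # scores of the positive pairs
    den = torch.logsumexp(mul, dim=0)    # log-sum-exp over each column
    return torch.mean(-num + den)


def info_nce_ce(zt, ztk):
    # Column j of mul holds the logits for anchor zt_j, so transpose to rows
    # and use the anchor's own index as the target class.
    mul = ztk @ zt.t()
    targets = torch.arange(mul.size(0), device=mul.device)
    return F.cross_entropy(mul.t(), targets)


torch.manual_seed(0)
zt = torch.randn(16, 3)
ztk = torch.randn(16, 3)
print(info_nce_manual(zt, ztk).item(), info_nce_ce(zt, ztk).item())
```

Writing the loss via `F.cross_entropy` delegates the max-subtraction trick to PyTorch. It also makes the column-wise direction of the softmax explicit.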

My model:

import torch.nn as nn
import torch.nn.functional as F

class Module1(nn.Module):

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(1, 32, 5, 1, 2)
    self.conv2 = nn.Conv2d(32, 32, 5, 1, 2)
    self.conv3 = nn.Conv2d(32, 64, 5, 1, 2)
    self.conv4 = nn.Conv2d(64, 64, 5, 1, 2)
    self.conv5 = nn.Conv2d(64, 128, 5, 1, 2)
    self.conv6 = nn.Conv2d(128, 128, 5, 1, 2)
    self.pool = nn.MaxPool2d(2, 2)
    # three 2x2 poolings: a 64x64 input ends up as 8x8 feature maps
    self.fc = nn.Linear(128*8*8, 3)

  def forward(self, x):
    x = self.pool(F.relu(self.conv2(F.relu(self.conv1(x)))))
    x = self.pool(F.relu(self.conv4(F.relu(self.conv3(x)))))
    x = self.pool(F.relu(self.conv6(F.relu(self.conv5(x)))))
    x = x.view(x.size(0), -1)
    x = self.fc(x)
    return x
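The convolutions use kernel 5 with padding 2, so they keep the spatial size. Each 2×2 max-pool halves it. `nn.Linear(128*8*8, 3)` therefore only fits 64×64 inputs. smallNORB images are 96×96, so presumably they are resized somewhere; the post does not show that step. The sketch below assumes single-channel 64×64 inputs and a hypothetical batch of 4. It re-declares the same model and checks the output shape.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Module1(nn.Module):
    # Same layers as the model above; spatial size goes 64 -> 32 -> 16 -> 8.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 5, 1, 2)
        self.conv2 = nn.Conv2d(32, 32, 5, 1, 2)
        self.conv3 = nn.Conv2d(32, 64, 5, 1, 2)
        self.conv4 = nn.Conv2d(64, 64, 5, 1, 2)
        self.conv5 = nn.Conv2d(64, 128, 5, 1, 2)
        self.conv6 = nn.Conv2d(128, 128, 5, 1, 2)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(128 * 8 * 8, 3)

    def forward(self, x):
        x = self.pool(F.relu(self.conv2(F.relu(self.conv1(x)))))
        x = self.pool(F.relu(self.conv4(F.relu(self.conv3(x)))))
        x = self.pool(F.relu(self.conv6(F.relu(self.conv5(x)))))
        x = x.view(x.size(0), -1)  # flatten to (B, 128 * 8 * 8)
        return self.fc(x)


model = Module1()
x = torch.randn(4, 1, 64, 64)  # hypothetical batch: 4 grayscale 64x64 images
print(model(x).shape)  # torch.Size([4, 3])
```

Feeding 96×96 images unchanged would leave 12×12 maps after the last pool. The flattened vector would then not match the linear layer, and PyTorch would raise a shape error.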

My training routine:

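from tqdm import trange
import torch.optim as optim

mod1 = Module1().to(device)
optimizer = optim.Adam(mod1.parameters(), lr=lr)

for epoch in trange(epochs):
  for i, data in enumerate(trainloader):
    # data[0] holds the pair (x_t, x_{t+k}) fed through the same network
    xt = data[0][0].to(device)
    xtk = data[0][1].to(device)
    zt = mod1(xt)
    ztk = mod1(xtk)
    loss = infoNCE(zt, ztk)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()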

The loss always stays around 2.77, because that is usually the value of the variable den, while num is fairly small. This did not happen when I trained on all of the samples.
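A quick arithmetic check on that number helps here. Suppose every entry in a column of mul has the same value c, for example because the embeddings collapse to the same point. Then num = c and den = c + log B, so the loss is exactly log B for any c. The post does not give the batch size. Assuming b_size = 16, log 16 ≈ 2.7726, which is close to the observed 2.77. The following sketch uses only the standard library and mirrors the max-subtracted log-sum-exp in infoNCE.

```python
import math


def info_nce_from_scores(scores):
    """InfoNCE as in infoNCE(): scores[i][j] = ztk_i . zt_j, positives on the diagonal."""
    b = len(scores)
    total = 0.0
    for j in range(b):
        column = [scores[i][j] for i in range(b)]
        m = max(column)
        den = math.log(sum(math.exp(s - m) for s in column)) + m
        total += -scores[j][j] + den
    return total / b


b_size = 16  # assumption: the post does not state the batch size
c = 0.3      # any constant score; the result does not depend on it
uniform = [[c] * b_size for _ in range(b_size)]
print(info_nce_from_scores(uniform))  # approximately 2.7726, i.e. math.log(16)
```

A loss pinned at log B means the positive pair scores no higher than the negatives in its column. This is one possible reading of the constant value; the post does not confirm it.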

EDIT: Yes, this number is suspiciously close to e and I’m using logarithms.