Thank you for the help. I am trying to use simCLR with my model but getting a lot of dimension errors. How should I change my model to correct the dimensions. The error is in the loss function :
RuntimeError: shape '[8, 1]' is invalid for input of size -4
for epoch in range(10):
print(f"Epoch [{epoch}/{epochs}]\t")
stime = time.time()
model.train()
tr_loss_epoch = 0
for step, (x_i, x_j) in enumerate(train_ldr):
optimizer.zero_grad()
x_i = x_i.unsqueeze(0)
x_j = x_j.unsqueeze(0)
x_j = x_j.unsqueeze(1)
print(x_i.shape)
print(x_j.shape)
# positive pair, with encoding
x_i = x_i.permute(0, 2, 1)
x_j = x_j.permute(0, 2, 1)
print('Permute',x_i.shape)
print('Permute',x_j.shape)
z_i = model(x_i)
x_j = x_j.repeat(1, 3, 1)
z_j = model(x_j)
print('z_i shape', z_i.shape)
print('z_j shape', z_j.shape)
loss = criterion(z_i, z_j)
loss.backward()
optimizer.step()
if nr == 0 and step % 50 == 0:
print(f"Step [{step}/{len(X_train)}]\t Loss: {round(loss.item(), 5)}")
tr_loss_epoch += loss.item()
criterion = SimCLR_Loss(batch_size = 4, temperature = 0.5)
class SimCLR_Loss(nn.Module):
def __init__(self, batch_size, temperature):
super(SimCLR_Loss,self).__init__()
self.batch_size = batch_size
self.temperature = temperature
self.mask = self.mask_correlated_samples(batch_size)
self.criterion = nn.CrossEntropyLoss(reduction="sum")
self.similarity_f = nn.CosineSimilarity(dim=2)
def mask_correlated_samples(self, batch_size):
N = 2 * batch_size
mask = torch.ones((N, N), dtype=bool)
mask = mask.fill_diagonal_(0)
for i in range(batch_size):
mask[i, batch_size + i] = 0
mask[batch_size + i, i] = 0
return mask
def forward(self, z_i, z_j):
N = 2 * self.batch_size
z = torch.cat((z_i, z_j), dim=0)
sim = self.similarity_f(z.unsqueeze(1), z.unsqueeze(0)) / self.temperature
sim_i_j = torch.diag(sim, self.batch_size)
sim_j_i = torch.diag(sim, -self.batch_size)
positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
negative_samples = sim[self.mask].reshape(N, -1)
#SIMCLR
labels = torch.from_numpy(np.array([0]*N)).reshape(-1).to(positive_samples.device).long() #.float()
logits = torch.cat((positive_samples, negative_samples), dim=1)
loss = self.criterion(logits, labels)
loss /= N
return loss
The shapes of z_i and z_j are:
z_i shape torch.Size([1, 1])
z_j shape torch.Size([1, 1])